"""
Tools for persisting object graphs with Django ORM.

These classes make the following assumptions about your data model:
- The object graph has no cyclic SQL insertion-order dependencies.
- Objects with pk=None are inserted; any other pk value indicates an update.
"""

import threading

from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.fields.related import ForeignKey

def is_managed(item):
    """Tell whether ``item`` is an instance of Django's ``Model`` class."""

    managed = isinstance(item, Model)
    return managed

def is_persistent(item):
    """Tell whether ``item`` already carries a primary key (``pk`` is not None)."""

    primary_key = item.pk
    return primary_key is not None

class GraphError(Exception):
    """Error raised by the object-graph tools in this module."""
    pass

class Collection(object):
    """
    Allows iterables to be persisted in an object graph in an easier way.

    A collection belongs to a ``parent`` object.  While the parent has no
    primary key, members are kept in an in-memory list.  Once the parent is
    persistent, members are read through ``get_item_query()`` and every
    ``add``/``remove`` is written to the database immediately.
    """

    @classmethod
    def set_property(cls, model, attr, parent_attr, set_attr=None):
        """
        Install a lazily created collection property on a model class.

        ``model`` is the class that receives the property, ``attr`` is the
        property name, and ``parent_attr``/``set_attr`` are forwarded to the
        collection constructor.  The collection instance is cached on the
        model instance under ``'_' + attr``.  The pair
        ``(attr, '_' + attr)`` is appended to ``model._collection_attrs``.
        """

        private_attr = '_' + attr

        def _get(self):
            # Create the collection on first access and cache it.
            val = getattr(self, private_attr, None)
            if val is None:
                val = cls(self, parent_attr, set_attr=set_attr)
                setattr(self, private_attr, val)
            return val

        def _set(self, val):
            setattr(self, private_attr, val)
        setattr(model, attr, property(_get, _set))

        collection_attrs = getattr(model, '_collection_attrs', None)
        if collection_attrs is None:
            collection_attrs = []
            setattr(model, '_collection_attrs', collection_attrs)
        collection_attrs.append((attr, private_attr))

    def __init__(self, parent, parent_attr, set_attr=None, list_cls=None):
        """
        ``parent`` owns the collection.  ``parent_attr`` is the attribute on
        each member that points back at the parent.  ``set_attr`` is the
        attribute on the parent used by ``get_item_query()``.  ``list_cls``
        is the factory for the in-memory list (defaults to ``list``).
        """
        self.parent = parent
        self.parent_attr = parent_attr
        self.set_attr = set_attr

        if list_cls is None:
            list_cls = list
        self.list_cls = list_cls

        # Holds list of items if parent is not persisted
        self._item_list = None

        # Holds set of items that have been accessed
        self._accessed_items = set()

    def __iter__(self):
        """Iterate over all items, recording each one as accessed."""
        for item in self._items:
            self._accessed_items.add(item)
            yield item

    def __getitem__(self, k):
        """
        Return the item at index ``k``, recording it as accessed.

        For a slice a generator is returned; its items are recorded as
        accessed only as they are consumed.
        """
        if isinstance(k, slice):
            def _gen():
                items = self._items.__getitem__(k)
                for item in items:
                    self._accessed_items.add(item)
                    yield item
            return _gen()
        else:
            item = self._items.__getitem__(k)
            self._accessed_items.add(item)
            return item

    def __len__(self):
        """Return the number of items, using ``count()`` for a QuerySet."""
        items = self._items
        if isinstance(items, QuerySet):
            return items.count()
        else:
            return len(items)

    def _get_accessed_items(self):
        """
        Return the items a graph save should consider.

        For a persistent parent this is the set of accessed items whose
        ``parent_attr`` still refers to an object with the parent's pk.
        For an unsaved parent it is the whole in-memory list.
        """
        if is_persistent(self.parent) is True:
            # Return all items that have been accessed,
            # but only if they still belong in the collection.
            accessed_items = set()
            for item in self._accessed_items:
                parent = getattr(item, self.parent_attr, None)
                if parent is None:
                    continue

                if parent.pk != self.parent.pk:
                    continue

                accessed_items.add(item)
            return accessed_items
        else:
            return self._items
    accessed_items = property(_get_accessed_items)

    def _get_items(self):
        """Returns all items in the collection."""

        if is_persistent(self.parent) is True:
            return self.get_item_query()
        else:
            if self._item_list is None:
                self._item_list = self.list_cls()
            return self._item_list
    _items = property(_get_items)

    def add(self, item):
        """
        Add an item to the collection.

        The item's ``parent_attr`` is pointed at the parent.  If the parent
        is persistent the item is saved at once; otherwise it is appended to
        the in-memory list unless that exact object is already present.

        Raises GraphError if ``item`` is not a managed object.
        """

        if is_managed(item) is False:
            # Bad things will happen
            raise GraphError('Cannot add unmanaged object.')

        setattr(item, self.parent_attr, self.parent)
        if is_persistent(self.parent) is True:
            item.save()
            return

        # Membership is decided by identity, not equality.
        items = self._items
        if not any(existing_item is item for existing_item in items):
            items.append(item)

    def remove(self, item, delete=False):
        """
        Remove an item from the collection.

        Unless ``delete`` is True the item's ``parent_attr`` is cleared.
        With a persistent parent the item is then deleted (``delete`` is
        True) or saved.  With an unsaved parent the item is dropped from the
        in-memory list, if present.
        """

        if delete is False:
            setattr(item, self.parent_attr, None)

        if is_persistent(self.parent) is True:
            if delete is True:
                item.delete()
            else:
                item.save()
            return

        # Match by identity, mirroring add().  An equality-based
        # ``in``/``list.remove`` could drop a different object that merely
        # compares equal to ``item``.
        items = self._items
        for index, existing_item in enumerate(items):
            if existing_item is item:
                del items[index]
                break

    def clear(self, delete=False):
        """Remove all items from the collection."""

        # Snapshot first: remove() changes the underlying items.
        to_remove = list(self._items)
        for item in to_remove:
            self.remove(item, delete=delete)

    def update(self, new):
        """Add every item of the iterable ``new`` to the collection."""

        for item in new:
            self.add(item)

    def get_item_query(self):
        """Override this method to use a custom QuerySet."""

        return getattr(self.parent, self.set_attr).all()

class DependencyList(object):
    """Pairs an object (``obj``) with its list of dependency records (``deps``)."""

    __slots__ = ['obj', 'deps']

    def __init__(self, obj=None, deps=None):
        # Both slots default to None when not supplied.
        self.obj, self.deps = obj, deps

class Dependency(object):
    """One dependency edge in the graph: a ``parent``, a ``field`` and a ``level``."""

    __slots__ = ['parent', 'field', 'level']

    def __init__(self, parent=None, field=None, level=None):
        # All three slots default to None when not supplied.
        self.parent, self.field, self.level = parent, field, level

class GraphSaver(object):
    """
    Provides methods to simplify the persistence of large object graphs.

    Saving walks the foreign keys and collections reachable from the given
    items and records, per object, which other objects reference it.  Objects
    are then saved from the highest dependency level down, and each saved
    object is re-assigned onto the fields that reference it.

    ``deps`` throughout is a dict mapping ``id(obj)`` to a DependencyList.
    """

    def _add_dep(self, parent, child, field, deps, level):
        """
        Record that ``parent`` references ``child`` through ``field``.

        Raises GraphError if ``child`` has no dependency list in ``deps``,
        or if ``parent`` is already recorded on the child's list.
        """

        dep_list = self._get_dep_list(child, deps)
        if dep_list is None:
            raise GraphError('Dependency list not found!')

        # NOTE(review): this only checks whether ``parent`` is already
        # listed for ``child``; it also fires when one parent references
        # the same child through two different fields.
        for dep in dep_list.deps:
            if dep.parent is parent:
                raise GraphError('Circular dependency detected.')

        dep_list.deps.append(Dependency(
            parent=parent,
            field=field,
            level=level))

    def _get_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to, or None"""

        return deps.get(id(obj), None)

    def _init_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to."""

        dep_list = self._get_dep_list(obj, deps)
        if dep_list is None:
            dep_list = DependencyList(obj=obj, deps=[])
            deps[id(obj)] = dep_list
        return dep_list

    def _build_deps(self, parent, deps, update=True, level=1):
        """
        Recursively register ``parent`` and everything it references.

        ``parent`` is registered for saving when it is not yet persistent or
        when ``update`` is True.  Each non-None foreign key target is
        processed one level deeper, then collection items are processed.
        """

        if self._get_dep_list(parent, deps) is not None:
            # This object has already had its dependencies added!
            return

        if (is_persistent(parent) is False) or (update is True):
            # Makes sure the parent obj shows
            # up in the dependency list, so that
            # it gets saved!
            self._init_dep_list(parent, deps)

        # NOTE(review): an object that is not registered above (persistent
        # with update=False) is not remembered as visited, so it is walked
        # again each time it is reached.

        for name in parent._meta.get_all_field_names():
            field_info = parent._meta.get_field_by_name(name)
            if field_info[2] is False:
                # This is a magic reverse reference to
                # some other related field. Ignore it.
                continue

            field = field_info[0]
            if isinstance(field, ForeignKey):
                child = getattr(parent, name)
                if child is not None:
                    # This is a dependency!
                    # Recursively add to deps.
                    self._build_deps(child, deps, update=update, level=level + 1)
                    # A child that was not registered (already persistent
                    # and update is False) will not be saved, so there is
                    # nothing to re-assign on the parent afterwards.
                    if self._get_dep_list(child, deps) is not None:
                        self._add_dep(parent, child, field, deps, level)

        # Save items in any collections
        self._build_collection_deps(parent, deps, update=update, level=level)

    def _build_collection_deps(self, parent, deps, update=True, level=1):
        """Add dependencies from collection objects."""

        collection_attrs = getattr(parent.__class__, '_collection_attrs', [])
        for collection_attr in collection_attrs:
            # collection_attr[0] is the public property name.
            collection = getattr(parent, collection_attr[0], None)
            if collection is not None:
                for item in collection.accessed_items:
                    self._build_deps(item, deps, update=update, level=level)

    def _save_deps(self, deps):
        """Group the dependency lists by level and save them."""

        # Group dependencies by their level
        dep_levels = {}
        for dep_list in deps.values():
            # An object's level is the highest level among the
            # dependencies recorded on it, or 0 if there are none.
            max_level = 0
            for dep in dep_list.deps:
                if dep.level > max_level:
                    max_level = dep.level

            dep_levels.setdefault(max_level, []).append(dep_list)

        self._save_by_level(dep_levels)

    def _save_by_level(self, levels):
        """Save every dependency list, highest level first."""

        # Save children by their dependency level
        for level in sorted(levels, reverse=True):
            for dep_list in levels[level]:
                self._save_dep_list(dep_list)

    def _save_dep_list(self, dep_list):
        """Save the list's object, then re-assign it on each referencing field."""

        # All children in the list should be the same,
        # so we only need to save the 1st one.
        child = dep_list.obj
        child.save()

        for dep in dep_list.deps:
            setattr(dep.parent, dep.field.name, child)

    def save_many(self, items, update=True):
        """
        Save multiple object graphs.

        Raises GraphError if any item is not a managed object.
        """

        deps = {}
        for item in items:
            if is_managed(item) is False:
                raise GraphError('Cannot save unmanaged item.')
            self._build_deps(item, deps, level=1, update=update)
        self._save_deps(deps)

    def save(self, item, update=True):
        """Save an object graph."""

        return self.save_many((item,), update=update)

class Session(object):
    """
    Modifies QuerySets to return consistent object graphs.

    This will modify query behavior for all code executed
    during a 'with' statement.

    Session state lives in a class-level ``threading.local``, so it is
    shared by all Session instances within one thread.  Only the instance
    that performed the setup tears the session down on exit.
    """

    _session_objs = threading.local()

    def __init__(self):
        # True while this instance is inside its 'with' block.
        self._entered = False
        # True if this instance performed the setup and must undo it.
        self._teardown_on_exit = False

    def __enter__(self, *args, **kwargs):
        """
        Setup session system.

        Returns this session so that ``with Session() as s`` binds it.
        Raises GraphError if this instance is already entered.
        """

        if self._entered is True:
            raise GraphError('Cannot re-enter session!')

        if self._is_setup() is False:
            self._setup()

        self._entered = True
        return self

    def __exit__(self, *args, **kwargs):
        """Tear down the session if this instance set it up, then reset."""

        try:
            if self._teardown_on_exit is True:
                self._teardown()
        finally:
            # Reset per-instance state so the object can be used again and
            # a stale flag cannot tear down a session owned by another
            # instance.
            self._entered = False
            self._teardown_on_exit = False

    def _is_setup(self):
        """Return whether a session is active in the current thread."""
        return getattr(self._session_objs, '_setup', False)

    def _setup(self):
        """Start a session: reset the cache and install the patches."""
        self._clear()
        self._patch_query_set()
        self._patch_model_base()
        self._session_objs._setup = True
        self._teardown_on_exit = True

    def _teardown(self):
        """End the session; the installed patches stay but become inactive."""
        self._clear()
        self._session_objs._setup = False

    def _clear(self):
        """Clear all existing session objects."""

        self._session_objs.objs = {}

    def _patch_query_set(self):
        """Monkey patches QuerySet to return cached objects."""

        if getattr(QuerySet, '_session_enabled', False) is True:
            return

        f = QuerySet.iterator

        def _gen(q, *args, **kwargs):
            # Swap each fetched item for the session's instance of it.
            itr = f(q, *args, **kwargs)
            for item in itr:
                yield self.add(item)

        def _s_iterator(q, *args, **kwargs):
            if self._is_setup():
                return _gen(q, *args, **kwargs)
            else:
                return f(q, *args, **kwargs)
        QuerySet.iterator = _s_iterator
        QuerySet._session_enabled = True

    def _patch_model_base(self):
        """Monkey patches Model to add saved items to the session."""

        if getattr(Model, '_session_enabled', False) is True:
            return

        f = Model.save

        def _s_save(item, *args, **kwargs):
            if not self._is_setup():
                return f(item, *args, **kwargs)

            # Only choose a mode when the caller passed neither keyword.
            mode_given = ('force_insert' in kwargs) or ('force_update' in kwargs)
            if is_persistent(item) is True:
                # updating this object
                if not mode_given:
                    kwargs['force_update'] = True
                # Check for a duplicate before writing.
                self.add_with_dup_check(item)
                return f(item, *args, **kwargs)

            # inserting this object
            if not mode_given:
                kwargs['force_insert'] = True
            result = f(item, *args, **kwargs)
            # Registered after the save, once the item has been written.
            self.add_with_dup_check(item)
            return result
        Model.save = _s_save
        Model._session_enabled = True

    def _obj_key(self, cls, pk):
        """Returns session key for an object."""

        # The "module.ClassName" part is cached on the class itself.
        cls_key = getattr(cls, '_cls_key', None)
        if cls_key is None:
            cls_key = cls.__module__ + '.' + cls.__name__
            cls._cls_key = cls_key

        return '%s:%s' % (cls_key, pk)

    def add(self, item):
        """
        Adds an object to the session.

        Added object is returned.
        If item's key matches existing key,
        existing object is returned.

        Raises GraphError if the item is not persistent.
        """

        if is_persistent(item) is False:
            raise GraphError('Cannot add unpersisted object.')

        key = self._obj_key(item.__class__, item.pk)
        existing = self._session_objs.objs.get(key, None)
        if existing is None:
            self._session_objs.objs[key] = item
            return item
        else:
            return existing

    def add_with_dup_check(self, item):
        """
        Add a persistent object to the session.

        Raises GraphError if a different instance with the same key is
        already in the session.
        """

        if self.add(item) is not item:
            raise GraphError('Instance with identical id already exists in session.')