This module contains classes that add new behavior to Django's ORM. Classes include:
Session
- Forces QuerySet objects to return identical instances when objects with the same primary key are queried.
- Similar to SQLAlchemy's Session.
GraphSaver
- Saves entire object graphs at once.
- Automatically detects object dependencies and saves them in the correct order.
Collection
- Makes one-to-many relationships easier to work with.
Instructions and more information are available on limscoder.com.
| """
Tools for persisting object graphs with Django ORM.
These classes make some assumptions about your data model:
- The object graph does not have cyclic SQL insertion-order dependencies.
- Objects with pk=None are being inserted; any other pk value indicates an update.
"""
import threading
from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.fields.related import ForeignKey
def is_managed(item):
    """Tell whether ``item`` is a Django model instance (i.e. handled by the ORM)."""
    managed = isinstance(item, Model)
    return managed
def is_persistent(item):
    """Tell whether ``item`` already has a primary key, i.e. has been saved."""
    primary_key = item.pk
    return primary_key is not None
class GraphError(Exception):
    """
    Raised for invalid object-graph operations in this module, such as
    unmanaged objects, dependency problems, or session conflicts.
    """
class Collection(object):
    """
    Allows iterables to be persisted in an object graph in an easier way.

    Wraps the "many" side of a one-to-many relationship. While the parent
    object is unsaved, items live in an in-memory list (``list_cls``). Once
    the parent has a primary key, items come from ``get_item_query()``.
    Items handed out by iteration or indexing are remembered, so they can be
    saved together with the parent.
    """

    @classmethod
    def set_property(cls, model, attr, parent_attr, set_attr=None):
        """
        Install a property named ``attr`` on ``model`` that returns a Collection.

        Call this method with a model class as the 1st argument.

        :param model: model class that owns the collection.
        :param attr: name of the property to create on ``model``.
        :param parent_attr: attribute on each item that points back at the owner.
        :param set_attr: attribute on the owner whose ``.all()`` queries the
            persisted items.
        """
        private_attr = '_' + attr

        def _get(self):
            val = getattr(self, private_attr, None)
            if val is None:
                # Lazily create the collection on first access.
                val = cls(self, parent_attr, set_attr=set_attr)
                setattr(self, private_attr, val)
            return val

        def _set(self, val):
            setattr(self, private_attr, val)

        setattr(model, attr, property(_get, _set))

        # Register the collection on ``model`` itself. A list inherited from a
        # base class is copied rather than mutated, so registering a
        # collection on a subclass does not leak it into the base class.
        collection_attrs = model.__dict__.get('_collection_attrs', None)
        if collection_attrs is None:
            collection_attrs = list(getattr(model, '_collection_attrs', []))
            setattr(model, '_collection_attrs', collection_attrs)
        collection_attrs.append((attr, private_attr))

    def __init__(self, parent, parent_attr, set_attr=None, list_cls=None):
        """
        :param parent: object that owns the collection.
        :param parent_attr: attribute on each item that points at ``parent``.
        :param set_attr: attribute on ``parent`` whose ``.all()`` returns the
            persisted items (see ``get_item_query``).
        :param list_cls: list type used while ``parent`` is unsaved
            (defaults to ``list``).
        """
        self.parent = parent
        self.parent_attr = parent_attr
        self.set_attr = set_attr
        self.list_cls = list if list_cls is None else list_cls
        # Holds list of items if parent is not persisted
        self._item_list = None
        # Holds set of items that have been accessed
        self._accessed_items = set()

    def __iter__(self):
        for item in self._items:
            self._accessed_items.add(item)
            yield item

    def __getitem__(self, k):
        """Return one item, or a generator over the items when ``k`` is a slice."""
        if isinstance(k, slice):
            def _gen():
                for item in self._items[k]:
                    self._accessed_items.add(item)
                    yield item
            return _gen()
        item = self._items[k]
        self._accessed_items.add(item)
        return item

    def __len__(self):
        items = self._items
        if isinstance(items, QuerySet):
            # Let the database count instead of loading every row.
            return items.count()
        return len(items)

    def _get_accessed_items(self):
        """
        Return the items that should be saved along with the parent.

        For a persisted parent, these are the accessed items that still
        point at it. For an unsaved parent, this is the whole in-memory list.
        """
        if not is_persistent(self.parent):
            return self._items
        accessed_items = set()
        for item in self._accessed_items:
            parent = getattr(item, self.parent_attr, None)
            # Skip items that were moved to another parent (or detached)
            # after they were accessed.
            if parent is not None and parent.pk == self.parent.pk:
                accessed_items.add(item)
        return accessed_items

    accessed_items = property(_get_accessed_items)

    def _get_items(self):
        """Returns all items in the collection."""
        if is_persistent(self.parent):
            return self.get_item_query()
        if self._item_list is None:
            self._item_list = self.list_cls()
        return self._item_list

    _items = property(_get_items)

    def add(self, item):
        """
        Add an item to the collection.

        Points ``item`` at the parent. When the parent is persisted, the item
        is saved immediately; otherwise it is appended to the in-memory list
        (at most once, matched by identity).

        :raises GraphError: if ``item`` is not a Django model instance.
        """
        if not is_managed(item):
            # Bad things will happen
            raise GraphError('Cannot add unmanaged object.')
        setattr(item, self.parent_attr, self.parent)
        if is_persistent(self.parent):
            item.save()
        elif not any(existing_item is item for existing_item in self._items):
            self._items.append(item)

    def remove(self, item, delete=False):
        """
        Remove an item from the collection.

        :param delete: when true and the parent is persisted, the item is
            deleted from the database. Otherwise the item's parent reference
            is cleared and, for a persisted parent, the item is saved.
        """
        if not delete:
            setattr(item, self.parent_attr, None)
        if is_persistent(self.parent):
            if delete:
                item.delete()
            else:
                item.save()
        else:
            # Match by identity, consistent with add(). Unsaved model
            # instances can compare equal to each other (in older Django,
            # equality is by pk and both pks are None), so ``in`` /
            # ``list.remove`` could drop the wrong item.
            items = self._items
            for index, existing_item in enumerate(items):
                if existing_item is item:
                    del items[index]
                    break

    def clear(self, delete=False):
        """Remove all items from the collection (see ``remove`` for ``delete``)."""
        # Snapshot first: removing mutates the underlying list.
        for item in list(self._items):
            self.remove(item, delete=delete)

    def update(self, new):
        """Add every item in ``new`` to the collection (like ``set.update``)."""
        for item in new:
            self.add(item)

    def get_item_query(self):
        """Override this method to use a custom QuerySet."""
        return getattr(self.parent, self.set_attr).all()
class DependencyList(object):
    """
    Pair an object in the graph with the Dependency records for every
    object that references it (parent == dependency).
    """
    __slots__ = ['obj', 'deps']

    def __init__(self, obj=None, deps=None):
        # obj: the object to save; deps: list of Dependency entries.
        self.obj, self.deps = obj, deps
class Dependency(object):
    """
    One edge of the graph: ``parent`` references an object via ``field``,
    discovered at traversal depth ``level``.
    """
    __slots__ = ['parent', 'field', 'level']

    def __init__(self, parent=None, field=None, level=None):
        self.parent, self.field, self.level = parent, field, level
class GraphSaver(object):
    """
    Provides methods to simplify the persistence of large object graphs.

    The graph is discovered by following ForeignKey fields and Collection
    properties. Every object that will be saved gets a DependencyList, keyed
    by ``id(obj)``, holding one Dependency per (referrer, field) pair that
    points at it. Objects are then saved so that each unsaved object is
    written before any object that references it.
    """

    def _add_dep(self, parent, child, field, deps, level):
        """
        Record that ``parent`` references ``child`` through ``field``.

        A parent may reference the same child through several fields
        (e.g. two ForeignKeys to the same row). Each field is recorded so it
        can be re-assigned after the child is saved. Recording the same
        (parent, field) pair twice is a no-op.

        :raises GraphError: if ``child`` has no dependency list.
        """
        dep_list = self._get_dep_list(child, deps)
        if dep_list is None:
            raise GraphError('Dependency list not found!')
        for dep in dep_list.deps:
            if dep.parent is parent and dep.field is field:
                return
        dep_list.deps.append(Dependency(
            parent=parent,
            field=field,
            level=level))

    def _get_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to, or None"""
        return deps.get(id(obj), None)

    def _init_dep_list(self, obj, deps):
        """Return the DependencyList for ``obj``, creating it if needed."""
        dep_list = self._get_dep_list(obj, deps)
        if dep_list is None:
            dep_list = DependencyList(obj=obj, deps=[])
            deps[id(obj)] = dep_list
        return dep_list

    def _build_deps(self, parent, deps, update=True, level=1, visited=None):
        """
        Recursively register ``parent`` and everything reachable from it.

        :param update: when False, already persisted objects are traversed
            but not saved.
        :param level: traversal depth, recorded on each Dependency.
        :param visited: dict of ``id(obj) -> obj`` already traversed. It stops
            infinite recursion through objects that get no DependencyList
            (persisted objects when ``update`` is False). Holding the object
            keeps its id() from being reused during the traversal.
        """
        if visited is None:
            visited = {}
        if id(parent) in visited or self._get_dep_list(parent, deps) is not None:
            # This object has already had its dependencies added!
            return
        visited[id(parent)] = parent
        if (not is_persistent(parent)) or update:
            # Makes sure the parent obj shows up in the dependency list,
            # so that it gets saved!
            self._init_dep_list(parent, deps)
        for name in parent._meta.get_all_field_names():
            field_info = parent._meta.get_field_by_name(name)
            if field_info[2] is False:
                # This is a magic reverse reference to
                # some other related field. Ignore it.
                continue
            field = field_info[0]
            if not isinstance(field, ForeignKey):
                continue
            child = getattr(parent, name)
            if child is None:
                continue
            # This is a dependency! Recursively add to deps.
            self._build_deps(child, deps, update=update, level=level + 1,
                             visited=visited)
            # A child without a DependencyList is already persisted and will
            # not be saved (update=False), so nothing needs to wait for it.
            if self._get_dep_list(child, deps) is not None:
                self._add_dep(parent, child, field, deps, level)
        # Save items in any collections
        self._build_collection_deps(parent, deps, update=update, level=level,
                                    visited=visited)

    def _build_collection_deps(self, parent, deps, update=True, level=1,
                               visited=None):
        """Add dependencies from collection objects."""
        if visited is None:
            visited = {}
        collection_attrs = getattr(parent.__class__, '_collection_attrs', [])
        for collection_attr in collection_attrs:
            collection = getattr(parent, collection_attr[0], None)
            if collection is not None:
                for item in collection.accessed_items:
                    self._build_deps(item, deps, update=update, level=level,
                                     visited=visited)

    def _dep_depth(self, key, deps, depths, in_progress):
        """
        Return the save depth of the object registered under ``key``.

        An unsaved object that has referrers is one level deeper than its
        deepest referrer. Everything else is at depth 0: objects with no
        referrers, persisted objects, and referrers that are not being saved.
        A persisted object's primary key is already known, so references to
        it never constrain the save order. Saving in decreasing depth
        therefore writes every unsaved object before anything pointing at it.

        :param depths: memo of already computed depths (``key -> depth``).
        :param in_progress: keys on the current recursion path.
        :raises GraphError: if unsaved objects reference each other in a cycle.
        """
        if key in depths:
            return depths[key]
        dep_list = deps.get(key, None)
        depth = 0
        if dep_list is not None and not is_persistent(dep_list.obj):
            if key in in_progress:
                raise GraphError('Circular dependency detected.')
            in_progress.add(key)
            for dep in dep_list.deps:
                parent_depth = self._dep_depth(id(dep.parent), deps, depths,
                                               in_progress)
                depth = max(depth, parent_depth + 1)
            in_progress.discard(key)
        depths[key] = depth
        return depth

    def _save_deps(self, deps):
        """Group dependency lists by save depth, then save deepest first."""
        depths = {}
        in_progress = set()
        dep_levels = {}
        for key, dep_list in deps.items():
            depth = self._dep_depth(key, deps, depths, in_progress)
            dep_levels.setdefault(depth, []).append(dep_list)
        self._save_by_level(dep_levels)

    def _save_by_level(self, levels):
        """Save every DependencyList in ``levels``, highest level first."""
        for level in sorted(levels, reverse=True):
            for dep_list in levels[level]:
                self._save_dep_list(dep_list)

    def _save_dep_list(self, dep_list):
        """
        Save ``dep_list.obj``, then re-assign it on every referrer's field.

        Re-assigning after the save lets each referrer pick up the primary
        key the object may have just received.
        NOTE(review): presumably required because assigning an unsaved
        instance to a ForeignKey does not track its later pk - confirm.
        """
        child = dep_list.obj
        child.save()
        for dep in dep_list.deps:
            setattr(dep.parent, dep.field.name, child)

    def save_many(self, items, update=True):
        """
        Save multiple object graphs.

        :param update: when False, already persisted objects are not re-saved.
        :raises GraphError: if an item is not a Django model instance, or the
            unsaved objects form a reference cycle.
        """
        deps = {}
        visited = {}
        for item in items:
            if not is_managed(item):
                raise GraphError('Cannot save unmanaged item.')
            self._build_deps(item, deps, level=1, update=update,
                             visited=visited)
        self._save_deps(deps)

    def save(self, item, update=True):
        """Save an object graph (see ``save_many``)."""
        return self.save_many((item,), update=update)
class Session(object):
    """
    Modifies QuerySets to return consistent object graphs.

    This will modify query behavior for all code executed
    during a 'with' statement. While a session is set up in the current
    thread, QuerySet.iterator() yields the cached instance for any
    (class, pk) already in the session. Model.save() also registers saved
    objects with the session.
    """

    # Per-thread state: ``_setup`` flag and ``objs`` (session key -> instance).
    _session_objs = threading.local()

    def __init__(self):
        self._entered = False
        # True only for the session that performed setup, so a nested
        # session does not tear down the outer one on exit.
        self._teardown_on_exit = False

    def __enter__(self, *args, **kwargs):
        """
        Set up the session system and return this session.

        :raises GraphError: if this Session instance was already entered.
        """
        if self._entered is True:
            raise GraphError('Cannot re-enter session!')
        if self._is_setup() is False:
            self._setup()
        self._entered = True
        return self

    def __exit__(self, *args, **kwargs):
        # Returns None, so exceptions raised in the 'with' body propagate.
        if self._teardown_on_exit is True:
            self._teardown()

    def _is_setup(self):
        return getattr(self._session_objs, '_setup', False)

    def _setup(self):
        self._clear()
        self._patch_query_set()
        self._patch_model_base()
        self._session_objs._setup = True
        self._teardown_on_exit = True

    def _teardown(self):
        self._clear()
        self._session_objs._setup = False

    def _clear(self):
        """Clear all existing session objects."""
        self._session_objs.objs = {}

    def _patch_query_set(self):
        """
        Monkey patches QuerySet to return cached objects.

        The patch is installed once per process and stays in place; it only
        takes effect in threads where a session is currently set up.
        """
        if getattr(QuerySet, '_session_enabled', False) is True:
            return
        f = QuerySet.iterator

        def _gen(q, *args, **kwargs):
            itr = f(q, *args, **kwargs)
            for item in itr:
                # Swap in the session's instance when one already exists.
                yield self.add(item)

        def _s_iterator(q, *args, **kwargs):
            if self._is_setup():
                return _gen(q, *args, **kwargs)
            else:
                return f(q, *args, **kwargs)

        QuerySet.iterator = _s_iterator
        QuerySet._session_enabled = True

    def _patch_model_base(self):
        """
        Monkey patches Model.save to add saved items to the session.

        Inside a session, saving an object with a pk forces an UPDATE and
        saving one without a pk forces an INSERT, unless the caller passed
        force_insert/force_update explicitly.
        """
        if getattr(Model, '_session_enabled', False) is True:
            return
        f = Model.save

        def _s_save(item, *args, **kwargs):
            if self._is_setup():
                if is_persistent(item) is True:
                    # updating this object
                    if ('force_insert' not in kwargs) and ('force_update' not in kwargs):
                        kwargs['force_update'] = True
                    # Refuse to save a second instance of an already tracked row.
                    self.add_with_dup_check(item)
                    f(item, *args, **kwargs)
                else:
                    # inserting this object
                    if ('force_insert' not in kwargs) and ('force_update' not in kwargs):
                        kwargs['force_insert'] = True
                    f(item, *args, **kwargs)
                    # Only now does the item have a pk to key it by.
                    self.add_with_dup_check(item)
            else:
                return f(item, *args, **kwargs)

        Model.save = _s_save
        Model._session_enabled = True

    def _obj_key(self, cls, pk):
        """
        Return the session key for an object of class ``cls`` with ``pk``.

        The class part of the key is cached on ``cls`` itself. It is read
        from ``cls.__dict__`` rather than with getattr, so a subclass never
        reuses its base class's cached key. With multi-table inheritance a
        parent row and its child row share the same pk, so a shared key
        would mix up their instances.
        """
        cls_key = cls.__dict__.get('_cls_key', None)
        if cls_key is None:
            cls_key = cls.__module__ + '.' + cls.__name__
            cls._cls_key = cls_key
        return '%s:%s' % (cls_key, pk)

    def add(self, item):
        """
        Adds an object to the session.
        Added object is returned.
        If item's key matches existing key,
        existing object is returned.

        :raises GraphError: if ``item`` has no primary key.
        """
        if is_persistent(item) is False:
            raise GraphError('Cannot add unpersisted object.')
        key = self._obj_key(item.__class__, item.pk)
        existing = self._session_objs.objs.get(key, None)
        if existing is None:
            self._session_objs.objs[key] = item
            return item
        else:
            return existing

    def add_with_dup_check(self, item):
        """
        Add a persistent object to the session.

        :raises GraphError: if a different instance with the same key is
            already in the session.
        """
        if self.add(item) is not item:
            raise GraphError('Instance with identical id already exists in session.')
|
More like this
- Template tag - list punctuation for a list of items by shapiromatron 1 year ago
- JSONRequestMiddleware adds a .json() method to your HttpRequests by cdcarter 1 year ago
- Serializer factory with Django Rest Framework by julio 1 year, 7 months ago
- Image compression before saving the new model / work with JPG, PNG by Schleidens 1 year, 8 months ago
- Help text hyperlinks by sa2812 1 year, 8 months ago
Comments
Please login first before commenting.