Login

orm_tools

Author:
limscoder
Posted:
January 2, 2011
Language:
Python
Version:
1.2
Score:
0 (after 0 ratings)

This module contains classes that add new behavior to Django's ORM. Classes include:

Session

  • Forces QuerySet objects to return identical instances when objects with the same primary key are queried.
  • Similar to SQLAlchemy Session

GraphSaver

  • Save entire object graphs at once.
  • Automatically detects object dependencies and saves them in the correct order.

Collection

  • Makes one-to-many relationships easier to manage.

Instructions and more information are available at limscoder.com.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
"""
Tools for persisting object graphs with Django ORM.

These classes make some assumptions about your data model:
  * The object graph has no cyclic SQL insertion-order dependencies.
  * Objects with pk=None are being inserted; any other pk value indicates an update.
"""

import threading

from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.fields.related import ForeignKey

def is_managed(item):
    """Tell whether *item* is a Django model instance, i.e. handled by the ORM."""

    managed = isinstance(item, Model)
    return managed

def is_persistent(item):
    """Report whether *item* has already been saved (its pk is not None)."""

    return not (item.pk is None)

class GraphError(Exception):
    """Raised when an object-graph, collection or session operation is invalid."""

class Collection(object):
    """
    Allows iterables to be persisted in an object graph in an easier way.

    A Collection manages the "many" side of a one-to-many relationship for
    a single parent object:

    * While the parent is unsaved, items live in an in-memory container
      (built with ``list_cls``) so GraphSaver can save them together with
      the parent.
    * Once the parent is saved, items are read through ``get_item_query()``
      and ``add``/``remove`` write to the database immediately.

    Items obtained by iterating or indexing are remembered so GraphSaver can
    include them when saving the parent's graph.
    """

    @classmethod
    def set_property(cls, model, attr, parent_attr, set_attr=None):
        """
        Install a property ``attr`` on ``model`` that returns a Collection.

        Call this method with a model class as the 1st argument. The
        Collection is created lazily on first access and cached on the
        instance under ``'_' + attr``.

        model       -- model class that receives the property
        attr        -- name of the property to create
        parent_attr -- attribute on each item that refers back to the parent
        set_attr    -- attribute on the parent whose ``.all()`` returns the
                       items once the parent is saved (typically the reverse
                       relation, e.g. ``item_set``)
        """

        private_attr = '_' + attr

        def _get(self):
            val = getattr(self, private_attr, None)
            if val is None:
                val = cls(self, parent_attr, set_attr=set_attr)
                setattr(self, private_attr, val)
            return val

        def _set(self, val):
            setattr(self, private_attr, val)
        setattr(model, attr, property(_get, _set))

        # Register the attribute so GraphSaver can find the collection.
        # Each model class gets its own list (seeded with any inherited
        # entries): appending to a list inherited from a base class would
        # register this class's collections on the base class as well.
        collection_attrs = vars(model).get('_collection_attrs')
        if collection_attrs is None:
            collection_attrs = list(getattr(model, '_collection_attrs', []))
            setattr(model, '_collection_attrs', collection_attrs)
        collection_attrs.append((attr, private_attr))

    def __init__(self, parent, parent_attr, set_attr=None, list_cls=None):
        """
        parent      -- object that owns the collection
        parent_attr -- item attribute that refers back to ``parent``
        set_attr    -- parent attribute used by ``get_item_query()``
        list_cls    -- container type used while the parent is unsaved
                       (defaults to ``list``)
        """
        self.parent = parent
        self.parent_attr = parent_attr
        self.set_attr = set_attr

        if list_cls is None:
            list_cls = list
        self.list_cls = list_cls

        # Holds list of items if parent is not persisted
        self._item_list = None

        # Holds set of items that have been accessed
        self._accessed_items = set()

    def __iter__(self):
        for item in self._items:
            self._accessed_items.add(item)
            yield item

    def __getitem__(self, k):
        """Index or slice the items; a slice returns a generator, not a list."""
        if isinstance(k, slice):
            def _gen():
                items = self._items.__getitem__(k)
                for item in items:
                    self._accessed_items.add(item)
                    yield item
            return _gen()
        else:
            item = self._items.__getitem__(k)
            self._accessed_items.add(item)
            return item

    def __len__(self):
        """Count items, using a COUNT query when the items come from a QuerySet."""
        items = self._items
        if isinstance(items, QuerySet):
            return items.count()
        else:
            return len(items)

    def _get_accessed_items(self):
        """
        Return the items GraphSaver should consider saving.

        For a saved parent: the items accessed so far that still refer to
        the parent (items detached or re-parented since they were read are
        skipped). For an unsaved parent: the whole in-memory container.
        """
        if is_persistent(self.parent) is True:
            accessed_items = set()
            for item in self._accessed_items:
                parent = getattr(item, self.parent_attr, None)
                if parent is None or parent.pk != self.parent.pk:
                    continue
                accessed_items.add(item)
            return accessed_items
        else:
            return self._items
    accessed_items = property(_get_accessed_items)

    def _get_items(self):
        """Returns all items in the collection."""

        if is_persistent(self.parent) is True:
            return self.get_item_query()
        else:
            if self._item_list is None:
                self._item_list = self.list_cls()
            return self._item_list
    _items = property(_get_items)

    def add(self, item):
        """
        Add ``item`` to the collection.

        The item's ``parent_attr`` is pointed at the parent. With a saved
        parent the item is saved immediately; otherwise it is appended to the
        in-memory container unless that exact instance is already there.

        Raises GraphError if ``item`` is not a Django model instance.
        """

        if is_managed(item) is False:
            # Bad things will happen
            raise GraphError('Cannot add unmanaged object.')

        setattr(item, self.parent_attr, self.parent)
        if is_persistent(self.parent) is True:
            item.save()
        else:
            items = self._items
            # Compare by identity: Django versions before 1.7 treat all
            # unsaved instances of a model (pk=None) as equal.
            if not any(existing is item for existing in items):
                items.append(item)

    def remove(self, item, delete=False):
        """
        Remove ``item`` from the collection.

        delete -- False: set the item's ``parent_attr`` to None.
                  True: delete the item (only when the parent is saved).

        With a saved parent the change is written to the database
        immediately. With an unsaved parent the item is only dropped from
        the in-memory container.
        """

        if delete is False:
            setattr(item, self.parent_attr, None)

        if is_persistent(self.parent) is True:
            if delete is True:
                item.delete()
            else:
                item.save()
        else:
            # Remove this exact instance. An equality-based
            # ``item in items`` / ``items.remove(item)`` could match and drop
            # a *different* unsaved item that merely compares equal,
            # mirroring the identity check used by add().
            items = self._items
            for index, existing in enumerate(items):
                if existing is item:
                    del items[index]
                    break

    def clear(self, delete=False):
        """Remove all items from the collection (see ``remove`` for ``delete``)."""

        # Snapshot first: remove() mutates the in-memory container, and for
        # a saved parent it changes what the query returns.
        for item in list(self._items):
            self.remove(item, delete=delete)

    def update(self, new):
        """Add every item of the iterable ``new`` to the collection (see ``add``)."""

        for item in new:
            self.add(item)

    def get_item_query(self):
        """
        Return the query (normally a QuerySet) of items for a saved parent.

        Override this method to use a custom QuerySet. The default calls
        ``.all()`` on the parent attribute named by ``set_attr``, so
        ``set_attr`` must be provided unless this method is overridden.
        """

        return getattr(self.parent, self.set_attr).all()

class DependencyList(object):
    """
    Pairs one object of the graph with the objects that depend on it.

    ``obj`` is the object that is saved; ``deps`` holds one Dependency per
    parent whose foreign key refers to ``obj``.
    """

    __slots__ = ['obj', 'deps']

    def __init__(self, obj=None, deps=None):
        self.obj, self.deps = obj, deps

class Dependency(object):
    """
    One edge of the dependency graph.

    ``parent`` refers to the child object through the foreign key ``field``;
    ``level`` is the recursion level at which the edge was recorded.
    """

    __slots__ = ['parent', 'field', 'level']

    def __init__(self, parent=None, field=None, level=None):
        self.parent, self.field, self.level = parent, field, level

class GraphSaver(object):
    """
    Provides methods to simplify the persistence of large object graphs.

    Saving happens in two phases:

    1. ``_build_deps`` walks the graph from each root, following foreign keys
       and Collection attributes. For every object that will be saved it
       records the "parents" whose foreign keys refer to that object.
    2. ``_save_deps`` saves every object before all of its parents. Each
       saved object is then re-assigned to its parents' foreign key fields.
    """

    def _add_dep(self, parent, child, field, deps, level):
        """
        Record that ``parent`` refers to ``child`` through foreign key ``field``.

        Raises GraphError if ``child`` has no dependency list in ``deps``.
        """

        dep_list = self._get_dep_list(child, deps)
        if dep_list is None:
            raise GraphError('Dependency list not found!')

        for dep in dep_list.deps:
            if dep.parent is parent and dep.field.name == field.name:
                # Already recorded. A parent may legitimately refer to the
                # same child through several *different* foreign keys
                # (e.g. created_by / modified_by); that is not a cycle.
                return

        dep_list.deps.append(Dependency(
            parent=parent,
            field=field,
            level=level))

    def _get_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to, or None"""

        return deps.get(id(obj), None)

    def _init_dep_list(self, obj, deps):
        """Return the dependency list for ``obj``, creating it if needed."""

        dep_list = self._get_dep_list(obj, deps)
        if dep_list is None:
            dep_list = DependencyList(obj=obj, deps=[])
            deps[id(obj)] = dep_list
        return dep_list

    def _build_deps(self, parent, deps, update=True, level=1):
        """
        Recursively collect the objects reachable from ``parent`` into ``deps``.

        parent -- model instance to start from
        deps   -- dict of id(obj) -> DependencyList, filled in place
        update -- if False, objects that already have a pk are not saved
        level  -- recursion level, stored on each recorded Dependency
        """

        if self._get_dep_list(parent, deps) is not None:
            # This object has already had its dependencies added!
            return

        # NOTE(review): objects skipped here (pk set and update=False) never
        # get a dependency list, so they are not marked as visited and may be
        # traversed more than once.
        if (is_persistent(parent) is False) or (update is True):
            # Makes sure the parent obj shows up in the dependency list,
            # so that it gets saved!
            self._init_dep_list(parent, deps)

        for name in parent._meta.get_all_field_names():
            field_info = parent._meta.get_field_by_name(name)
            if field_info[2] is False:
                # This is a magic reverse reference to
                # some other related field. Ignore it.
                continue

            field = field_info[0]
            if not isinstance(field, ForeignKey):
                continue

            child = getattr(parent, name)
            if child is None:
                continue

            # This is a dependency! Recursively add to deps.
            self._build_deps(child, deps, update=update, level=level + 1)
            if self._get_dep_list(child, deps) is None:
                # The child already has a pk and update=False. It will not
                # be saved, so there is nothing to order or copy back onto
                # the parent.
                continue
            self._add_dep(parent, child, field, deps, level)

        # Save items in any collections
        self._build_collection_deps(parent, deps, update=update, level=level)

    def _build_collection_deps(self, parent, deps, update=True, level=1):
        """Add dependencies from collection objects."""

        # Entries are (public_attr, private_attr) pairs registered by
        # Collection.set_property.
        collection_attrs = getattr(parent.__class__, '_collection_attrs', [])
        for collection_attr in collection_attrs:
            collection = getattr(parent, collection_attr[0], None)
            if collection is None:
                continue
            for item in collection.accessed_items:
                self._build_deps(item, deps, update=update, level=level)

    def _rank_deps(self, deps):
        """
        Return a dict of id(obj) -> save rank for every entry in ``deps``.

        An object's rank is 0 when nothing depends on it. Otherwise it is one
        more than the highest rank among its parents, so every child outranks
        all of the objects that refer to it.

        Edges that close a cycle are ignored. This keeps graphs of
        already-saved objects that refer to each other saveable; the module
        assumes there are no cyclic *insertion* dependencies.
        """

        ranks = {}
        in_progress = set()

        def rank_of(obj):
            key = id(obj)
            if key in ranks:
                return ranks[key]
            rank = 0
            dep_list = deps.get(key)
            if dep_list is not None:
                in_progress.add(key)
                for dep in dep_list.deps:
                    if id(dep.parent) in in_progress:
                        # Back edge of a cycle; ignore it.
                        continue
                    rank = max(rank, rank_of(dep.parent) + 1)
                in_progress.discard(key)
            ranks[key] = rank
            return rank

        for dep_list in deps.values():
            rank_of(dep_list.obj)
        return ranks

    def _save_deps(self, deps):
        """
        Save every object in ``deps``, each child before its parents.

        The recursion level stored on each Dependency cannot be used for
        ordering. Collection items are visited at the same level as their
        owner, so a child and a parent could tie and be saved in arbitrary
        order. Ranks from ``_rank_deps`` avoid such ties.
        """

        ranks = self._rank_deps(deps)

        # Group dependency lists by rank
        dep_levels = {}
        for key, dep_list in deps.items():
            dep_levels.setdefault(ranks[key], []).append(dep_list)

        self._save_by_level(dep_levels)

    def _save_by_level(self, levels):
        """Save the dependency lists in ``levels``, highest level first."""

        for level in sorted(levels, reverse=True):
            for dep_list in levels[level]:
                self._save_dep_list(dep_list)

    def _save_dep_list(self, dep_list):
        """Save ``dep_list.obj`` and re-assign it to each parent's foreign key."""

        child = dep_list.obj
        child.save()

        # Re-assign after saving so each parent's foreign key reflects the
        # child's (possibly newly assigned) pk.
        for dep in dep_list.deps:
            setattr(dep.parent, dep.field.name, child)

    def save_many(self, items, update=True):
        """
        Save multiple object graphs.

        items  -- iterable of model instances (graph roots)
        update -- if False, objects that already have a pk are not re-saved

        Raises GraphError if any item is not a Django model instance.
        """

        deps = {}
        for item in items:
            if is_managed(item) is False:
                raise GraphError('Cannot save unmanaged item.')
            self._build_deps(item, deps, level=1, update=update)
        self._save_deps(deps)

    def save(self, item, update=True):
        """Save the object graph reachable from ``item`` (see ``save_many``)."""

        return self.save_many((item,), update=update)

class Session(object):
    """
    Modifies QuerySets to return consistent object graphs.

    This will modify query behavior for all code executed during a 'with'
    statement. While a session is active in the current thread:

    * every instance produced by QuerySet.iterator() is looked up in a
      per-thread identity map keyed by model class and pk. The instance
      already in the map is returned instead of a new copy.
    * Model.save() registers the saved instance in the same map. It raises
      GraphError if a *different* instance with the same key is already
      registered.

    The QuerySet/Model monkey patches are installed once and stay in place.
    They fall back to the original behavior when no session is active in the
    calling thread.

    Sessions may be nested: only the outermost one sets up and tears down
    the identity map. Each Session instance can be entered only once.

    Usage::

        with Session() as session:
            ...
    """

    # Per-thread state shared by every Session instance:
    #   _setup -- True while a session is active in this thread
    #   objs   -- identity map of 'module.Class:pk' -> instance
    _session_objs = threading.local()

    def __init__(self):
        self._entered = False
        # True only for the Session that actually ran _setup(), so a nested
        # session leaves the outer session's identity map alone on exit.
        self._teardown_on_exit = False

    def __enter__(self, *args, **kwargs):
        """
        Activate the session in the current thread and return this Session.

        Raises GraphError if this Session has already been entered.
        """

        if self._entered is True:
            raise GraphError('Cannot re-enter session!')

        if self._is_setup() is False:
            self._setup()

        self._entered = True
        return self

    def __exit__(self, *args, **kwargs):
        """Tear down the session if this Session set it up; exceptions propagate."""

        if self._teardown_on_exit is True:
            self._teardown()

    def _is_setup(self):
        """Return whether a session is active in the current thread."""

        return getattr(self._session_objs, '_setup', False)

    def _setup(self):
        """Start with an empty identity map and install the monkey patches."""

        self._clear()
        self._patch_query_set()
        self._patch_model_base()
        self._session_objs._setup = True
        self._teardown_on_exit = True

    def _teardown(self):
        """Discard the identity map and deactivate the session in this thread."""

        self._clear()
        self._session_objs._setup = False

    def _clear(self):
        """Clear all existing session objects."""

        self._session_objs.objs = {}

    def _patch_query_set(self):
        """Monkey patches QuerySet.iterator (once) to return cached objects."""

        if getattr(QuerySet, '_session_enabled', False) is True:
            return

        # The patch closes over this instance, but all state it touches lives
        # on the class-level thread-local, so it serves every Session.
        f = QuerySet.iterator

        def _gen(q, *args, **kwargs):
            for item in f(q, *args, **kwargs):
                yield self.add(item)

        def _s_iterator(q, *args, **kwargs):
            if self._is_setup():
                return _gen(q, *args, **kwargs)
            return f(q, *args, **kwargs)

        QuerySet.iterator = _s_iterator
        QuerySet._session_enabled = True

    def _patch_model_base(self):
        """Monkey patches Model.save (once) to add saved items to the session."""

        if getattr(Model, '_session_enabled', False) is True:
            return

        f = Model.save

        def _s_save(item, *args, **kwargs):
            if not self._is_setup():
                return f(item, *args, **kwargs)

            # Pick force_insert/force_update only when the caller specified
            # neither. Any positional argument is force_insert, so adding a
            # keyword then would duplicate or contradict the caller's choice.
            choose_mode = (not args
                           and 'force_insert' not in kwargs
                           and 'force_update' not in kwargs)

            if is_persistent(item) is True:
                # updating this object
                if choose_mode:
                    kwargs['force_update'] = True
                self.add_with_dup_check(item)
                return f(item, *args, **kwargs)

            # inserting this object; it only has a pk after the save
            if choose_mode:
                kwargs['force_insert'] = True
            result = f(item, *args, **kwargs)
            self.add_with_dup_check(item)
            return result

        Model.save = _s_save
        Model._session_enabled = True

    def _obj_key(self, cls, pk):
        """Return the identity-map key 'module.ClassName:pk' for an object."""

        # The class part is cached on the class as _cls_key.
        # NOTE(review): getattr() also finds a _cls_key cached on a base
        # class, so a subclass may share its parent's key depending on which
        # class was keyed first. Confirm this is intended (e.g. for Django's
        # deferred-field subclasses) before relying on it.
        cls_key = getattr(cls, '_cls_key', None)
        if cls_key is None:
            cls_key = cls.__module__ + '.' + cls.__name__
            cls._cls_key = cls_key

        return '%s:%s' % (cls_key, pk)

    def add(self, item):
        """
        Adds an object to the session.

        Returns the instance stored for item's class and pk: ``item`` itself
        if the key was new, otherwise the instance already in the session.

        Raises GraphError if ``item`` has no pk yet.
        """

        if is_persistent(item) is False:
            raise GraphError('Cannot add unpersisted object.')

        key = self._obj_key(item.__class__, item.pk)
        return self._session_objs.objs.setdefault(key, item)

    def add_with_dup_check(self, item):
        """
        Add a persistent object to the session.

        Raises GraphError if a different instance with the same class and pk
        is already in the session.
        """

        if self.add(item) is not item:
            raise GraphError('Instance with identical id already exists in session.')

More like this

  1. Template tag - list punctuation for a list of items by shapiromatron 8 months, 4 weeks ago
  2. JSONRequestMiddleware adds a .json() method to your HttpRequests by cdcarter 9 months ago
  3. Serializer factory with Django Rest Framework by julio 1 year, 3 months ago
  4. Image compression before saving the new model / work with JPG, PNG by Schleidens 1 year, 4 months ago
  5. Help text hyperlinks by sa2812 1 year, 5 months ago

Comments

Please login first before commenting.