An extension of Django's ManyToManyField that limits the maximum number of relationships in one or both directions. See the class docstring for more information and sample usage.
from django.db.models import ManyToManyField, Count, signals
from django.db.models.fields.related import add_lazy_relation
class RelationCardinalityException(Exception):
    '''Raised when adding relationships would exceed a configured maximum
    cardinality of a :class:`MaxCardinalityManyToManyField`.'''
class MaxCardinalityManyToManyField(ManyToManyField):
    '''A ManyToManyField that constrains the maximum number of relationships in
    one or both directions.

    An upper bound can be set for the forward relationships (``max_cardinality``)
    and/or the reverse relationships (``reverse_max_cardinality``). If either is
    left undefined (or None), it defaults to unbounded; a bound of ``0`` forbids
    any relationship in that direction. Attempting to add one or more
    relationships that would result in exceeding the bound(s) raises a
    :class:`RelationCardinalityException`.

    For symmetric relationships, ``max_cardinality`` and
    ``reverse_max_cardinality`` must be equal (a ``ValueError`` is raised
    otherwise). As a shortcut, leaving one of the two undefined defaults to
    the other, so just defining one of them is enough.

    Example::

        class Topping(models.Model):
            name = models.CharField(max_length=128, unique=True)

        class Pizza(models.Model):
            name = models.CharField(max_length=128, unique=True)
            toppings = MaxCardinalityManyToManyField(Topping,
                                                     max_cardinality=2,
                                                     reverse_max_cardinality=3)

        >>> mushrooms = Topping.objects.get_or_create(name='mushrooms')[0]
        >>> anchovies = Topping.objects.get_or_create(name='anchovies')[0]
        >>> mozzarella = Topping.objects.get_or_create(name='mozzarella')[0]
        >>> margherita = Pizza.objects.get_or_create(name='margherita')[0]
        >>> marinara = Pizza.objects.get_or_create(name='marinara')[0]
        >>> sicilian = Pizza.objects.get_or_create(name='sicilian')[0]
        >>> california = Pizza.objects.get_or_create(name='california')[0]

        >>> # try to exceed max_cardinality through 'toppings'
        >>> margherita.toppings.add(mushrooms, anchovies)
        >>> margherita.toppings.add(mozzarella)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for pizza.pk=1
        >>> margherita.toppings.clear()

        >>> # try to exceed max_cardinality through 'pizza_set'
        >>> for topping in mushrooms, mozzarella:
        ...     topping.pizza_set = [marinara, sicilian]
        >>> anchovies.pizza_set.add(sicilian)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for pizza.pk=3
        >>> for topping in mushrooms, mozzarella, anchovies:
        ...     topping.pizza_set.clear()

        >>> # try to exceed reverse_max_cardinality through 'pizza_set'
        >>> mushrooms.pizza_set.add(margherita, marinara, sicilian)
        >>> mushrooms.pizza_set.add(california)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for topping.pk=1
        >>> mushrooms.pizza_set.clear()

        >>> # try to exceed reverse_max_cardinality through 'toppings'
        >>> for pizza in margherita, marinara, sicilian:
        ...     pizza.toppings = [mushrooms, mozzarella]
        >>> california.toppings.add(mushrooms)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for topping.pk=1
    '''

    def __init__(self, to, **kwargs):
        '''Accept the extra ``max_cardinality`` / ``reverse_max_cardinality``
        keyword arguments; everything else is passed to ManyToManyField.

        Raises ValueError for a symmetrical relation whose two bounds differ.
        '''
        self.max_cardinality = kwargs.pop('max_cardinality', None)
        self.reverse_max_cardinality = kwargs.pop('reverse_max_cardinality', None)
        super(MaxCardinalityManyToManyField, self).__init__(to, **kwargs)
        if self.rel.symmetrical:
            # Both directions of a symmetrical relation are the same relation,
            # so the two bounds must agree; fill in whichever one was omitted.
            if self.reverse_max_cardinality is None:
                self.reverse_max_cardinality = self.max_cardinality
            elif self.max_cardinality is None:
                self.max_cardinality = self.reverse_max_cardinality
            elif self.max_cardinality != self.reverse_max_cardinality:
                raise ValueError('Symmetrical relationships must have equal '
                                 'forward and reverse max cardinality')

    def contribute_to_class(self, cls, name):
        '''Install the field, then hook cardinality validation onto the
        intermediary ("through") model if any bound is configured.'''
        super(MaxCardinalityManyToManyField, self).contribute_to_class(cls, name)
        # Compare against None rather than truthiness: a bound of 0 is a real
        # constraint ("no relationships allowed") and must still be enforced.
        if self.max_cardinality is None and self.reverse_max_cardinality is None:
            return
        through = self.rel.through
        if not through:
            return
        if isinstance(through, basestring):
            # The intermediary model is given by name and may not be loaded
            # yet; add_lazy_relation invokes the callback once it resolves.
            add_lazy_relation(cls, self, through,
                              lambda field, model, _cls: field.__connect_through_signals(model))
        else:
            self.__connect_through_signals(through)

    def __connect_through_signals(self, through):
        '''Connect pre_save and m2m_changed receivers on ``through`` that
        validate cardinalities before new relationships are stored.'''

        def fk_value(instance, field_name):
            # Read the raw FK id via the field's attname. m2m_column_name()
            # returns the DB column, which differs from the attribute name
            # when a custom through model sets db_column.
            return getattr(instance, instance._meta.get_field(field_name).attname)

        def validate_cardinalities(sender, instance, **kwargs):
            # Only rows not yet in the database are new relationships;
            # re-saving an existing row does not change any count.
            pk = instance._get_pk_val()
            # XXX: _base_manager or _default_manager ?
            exists = pk is not None and instance.__class__._base_manager.filter(pk=pk).exists()
            if not exists:
                self._validate_cardinality(fk_value(instance, self.m2m_field_name()),
                                           reverse=False)
                self._validate_cardinality(fk_value(instance, self.m2m_reverse_field_name()),
                                           reverse=True)
        signals.pre_save.connect(validate_cardinalities, sender=through, weak=False)

        def m2m_validate_cardinalities(sender, instance, action, reverse, pk_set, **kwargs):
            # Only additions can exceed an upper bound.
            if action != 'pre_add' or not pk_set:
                return
            if reverse:
                # Added from the target side: each source pk in pk_set gains
                # one relationship, while ``instance`` gains len(pk_set).
                self._validate_cardinality(*pk_set, reverse=False)
                self._validate_cardinality(instance._get_pk_val(),
                                           reverse=True, num_added=len(pk_set))
            else:
                # Added from the source side: ``instance`` gains len(pk_set),
                # each target pk in pk_set gains one relationship.
                self._validate_cardinality(instance._get_pk_val(),
                                           reverse=False, num_added=len(pk_set))
                self._validate_cardinality(*pk_set, reverse=True)
        signals.m2m_changed.connect(m2m_validate_cardinalities, sender=through, weak=False)

    def _validate_cardinality(self, *pks, **kwargs):
        '''Raise RelationCardinalityException if any of ``pks`` would exceed
        its bound after gaining ``num_added`` relationships.

        Keyword arguments:
            reverse (required): False to check the source side against
                ``max_cardinality``, True to check the target side against
                ``reverse_max_cardinality`` (symmetrical relations always use
                the forward side).
            num_added (default 1): relationships each pk is about to gain.
        '''
        if not kwargs['reverse'] or self.rel.symmetrical:
            field_name = self.m2m_field_name()
            threshold = self.max_cardinality
        else:
            field_name = self.m2m_reverse_field_name()
            threshold = self.reverse_max_cardinality
        if threshold is None:  # unbounded in this direction
            return
        # count + num_added > max  <=>  count > max - num_added
        threshold -= kwargs.get('num_added', 1)
        for pk, count in self._get_counts(field_name, pks).iteritems():
            if count > threshold:
                raise RelationCardinalityException('No more %s allowed for %s.pk=%s' % (
                    unicode(self.rel.through._meta.verbose_name_plural), field_name, pk))

    def _get_counts(self, field_name, pks):
        '''Return a dict mapping each pk in ``pks`` to the number of
        intermediary rows whose ``field_name`` equals it (0 if none).'''
        pk2count = dict.fromkeys(pks, 0)  # ensure that all pks have a count
        # values_list + annotate yields (pk, count) pairs grouped by field_name.
        pk2count.update(
            self.rel.through._default_manager.values_list(field_name).
            filter(**{field_name + '__in': pk2count}).annotate(Count(field_name)))
        return pk2count
|
More like this
- Template tag - list punctuation for a list of items by shapiromatron 1 year ago
- JSONRequestMiddleware adds a .json() method to your HttpRequests by cdcarter 1 year ago
- Serializer factory with Django Rest Framework by julio 1 year, 7 months ago
- Image compression before saving the new model / work with JPG, PNG by Schleidens 1 year, 8 months ago
- Help text hyperlinks by sa2812 1 year, 8 months ago
Comments
Please login first before commenting.