from abc import abstractmethod

from django.db import models
from django.db.models import query


def _expand_model_class(kwargs):
    """Translate a ``model_class`` lookup into ``mod`` / ``cls`` lookups.

    If *kwargs* contains ``model_class``, it is removed and replaced by the
    class's ``__module__`` (as ``mod``) and ``__name__`` (as ``cls``), which
    are the two columns ``PolymorphicModel`` stores. The dict is modified in
    place and returned for convenience.
    """
    if 'model_class' in kwargs:
        model_class = kwargs.pop('model_class')
        kwargs['mod'] = model_class.__module__
        kwargs['cls'] = model_class.__name__
    return kwargs


class PolymorphicQuerySet(query.QuerySet):
    """QuerySet whose ``get``/``filter`` accept a ``model_class=<class>`` lookup."""

    def get(self, *args, **kwargs):
        """Like ``QuerySet.get`` but understands ``model_class=<class>``."""
        return super(PolymorphicQuerySet, self).get(
            *args, **_expand_model_class(kwargs))

    def filter(self, *args, **kwargs):
        """Like ``QuerySet.filter`` but understands ``model_class=<class>``."""
        return super(PolymorphicQuerySet, self).filter(
            *args, **_expand_model_class(kwargs))


class PolymorphicManager(models.Manager):
    """Manager that hands out ``PolymorphicQuerySet`` instances."""

    # NOTE(review): Django 1.6 renamed this hook to ``get_queryset``; the
    # original name is kept here -- confirm against the Django version in use.
    def get_query_set(self):
        """Return a ``PolymorphicQuerySet`` bound to this manager's model/db."""
        return PolymorphicQuerySet(self.model, using=self._db)

    def get(self, *args, **kwargs):
        """Delegate to ``PolymorphicQuerySet.get``."""
        return self.get_query_set().get(*args, **kwargs)

    def filter(self, *args, **kwargs):
        """Delegate to ``PolymorphicQuerySet.filter``."""
        return self.get_query_set().filter(*args, **kwargs)


class PolymorphicModel(models.Model):
    """Abstract model that re-creates rows as the class that saved them.

    ``save`` records the saving class's module and name in the ``mod`` and
    ``cls`` columns; ``__new__`` reads those values back from the constructor
    arguments and instantiates that class instead of the one requested.
    """

    # Cache of already-imported classes, keyed by (module name, class name).
    # Deliberately a single dict shared by every subclass.
    classes = {}

    mod = models.CharField(max_length=50)
    cls = models.CharField(max_length=30)

    class Meta:
        abstract = True

    def __new__(cls, *args, **kwargs):
        """Create an instance of the class named by the ``mod``/``cls`` values.

        The values are looked up first among the positional arguments (matched
        against ``cls._meta.fields`` in order) and then in the keyword
        arguments. When no ``cls`` value is supplied, the requested class
        itself is instantiated.

        Raises:
            IndexError: more positional arguments than model fields.
            AssertionError: a ``cls`` value was given without a ``mod`` value.
        """
        fields = cls._meta.fields
        if len(args) > len(fields):
            raise IndexError("Number of args exceeds number of fields")

        class_name = None
        module_name = None
        # zip stops at the shorter sequence, so only the fields that actually
        # received a positional value are inspected.
        for value, field in zip(args, fields):
            if field.name == 'cls':
                class_name = value
            elif field.name == 'mod':
                module_name = value
        if class_name is None:
            class_name = kwargs.get('cls')
        if module_name is None:
            module_name = kwargs.get('mod')

        if class_name is not None:
            assert module_name is not None, \
                "'cls' was supplied without a matching 'mod'"
            key = (str(module_name), str(class_name))
            try:
                cls = PolymorphicModel.classes[key]
            except KeyError:
                # A non-empty fromlist makes __import__ return the leaf module
                # rather than the top-level package.
                module = __import__(key[0], globals(), locals(), [key[1]])
                cls = getattr(module, key[1])
                PolymorphicModel.classes[key] = cls

        # Constructor arguments are not forwarded: object.__new__ rejects
        # extra arguments when __new__ is overridden. __init__ still receives
        # them from the normal call machinery.
        return super(PolymorphicModel, cls).__new__(cls)

    def save(self, *args, **kwargs):
        """Record the concrete class on first save, then save as usual."""
        if not self.cls:
            self.mod = self.__class__.__module__
            self.cls = self.__class__.__name__
        super(PolymorphicModel, self).save(*args, **kwargs)


class SomeBase(PolymorphicModel):
    """Example concrete base; the proxy subclasses supply ``something``."""

    a = models.IntegerField(default=0)
    b = models.IntegerField(default=0)

    # NOTE(review): abstractmethod is only enforced under an ABCMeta-derived
    # metaclass, so here it presumably acts as a marker only -- confirm.
    @abstractmethod
    def something(self):
        pass


class SomeDerivedA(SomeBase):
    """Proxy of ``SomeBase`` whose ``something`` increments ``a``."""

    class Meta:
        proxy = True

    def something(self):
        self.a += 1


class SomeDerivedB(SomeBase):
    """Proxy of ``SomeBase`` whose ``something`` increments ``b``."""

    class Meta:
        proxy = True

    def something(self):
        self.b += 1


class SomeCollection(models.Model):
    """Example model holding a many-to-many set of ``SomeBase`` rows."""

    values = models.ManyToManyField(SomeBase)


# Intended usage (illustrative session carried over from the original text;
# kept as a comment because it is not valid module-level code):
#
# >>> a = SomeDerivedA()
# >>> a.save()
# >>> b = SomeDerivedB()
# >>> b.save()
# >>> x = SomeCollection()
# >>> x.values = [a, b]
# >>> x.save()
# >>> y = SomeCollection.objects.get(pk=x.pk)
# >>> y.values.objects[0].something()
# >>> y.values.objects[0].a
# 1
# >>> y.values.objects[0].b
# 0
# >>> y.values.objects[1].something()
# >>> y.values.objects[1].a
# 0
# >>> y.values.objects[1].b
# 1
#
# NOTE(review): the session reads as pseudo-code. ``x.values`` is assigned
# before ``x.save()`` and ``y.values.objects[...]`` is presumably meant to be
# ``y.values.all()[...]`` -- confirm before turning this into a real doctest.