from django.db import models
from django.db.models import QuerySet
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query import normalize_prefetch_lookups
from rest_framework import serializers
from rest_framework.utils import model_meta


class ModelViewSetMetaclass(type):
    """Optimize a viewset's ``queryset`` with ``select_related``/``prefetch_related``.

    When a class built with this metaclass declares a ``queryset`` attribute,
    the queryset is rewritten at class-creation time so that the relations its
    ``serializer_class`` renders are fetched eagerly:

    * forward, single-valued relations (e.g. ForeignKey) are passed to
      ``select_related``;
    * to-many relations (forward or reverse), and ``__`` lookups whose first
      segment is a ForeignKey, are passed to ``prefetch_related``.

    Candidate names are collected from:

    * ``_base_forward_rel`` declared on the class or on any of its bases
      (candidates for ``select_related``);
    * ``_related_fields``, ``_many_to_many_fields`` and ``_many_to_one_fields``
      declared on the ``serializer_class``;
    * ``serializer_class.Meta.fields`` (including ``'__all__'``) or, when only
      ``Meta.exclude`` is given, the default model fields minus the excluded
      ones.

    Every candidate is checked against ``serializer_class.Meta.model``; names
    that are not the expected kind of relation are ignored. If the serializer
    is not a ``ModelSerializer`` with a ``Meta.model``, no relations are added.
    """

    @classmethod
    def get_lookups(cls, fields, strict=False):
        """Pair each lookup string in *fields* with its first path segment.

        ``'author__profile'`` yields ``('author', 'author__profile')`` and
        ``'author'`` yields ``('author', 'author')``. With ``strict=True`` only
        lookups containing ``LOOKUP_SEP`` (i.e. that traverse a relation) are
        returned.
        """
        pairs = [(lookup.split(LOOKUP_SEP, 1)[0], lookup) for lookup in fields]
        if strict:
            pairs = [pair for pair in pairs if LOOKUP_SEP in pair[1]]
        return pairs

    @classmethod
    def get_many_to_many_rel(cls, info, meta_fields):
        """Return the lookups in *meta_fields* that start with a to-many relation.

        *info* is the ``FieldInfo`` from ``model_meta.get_field_info``; both
        forward and reverse relations (``info.relations``) are considered.
        """
        to_many = {
            field_name
            for field_name, relation_info in info.relations.items()
            if relation_info.to_many
        }
        return [
            lookup
            for lookup_name, lookup in cls.get_lookups(meta_fields)
            if lookup_name in to_many
        ]

    @classmethod
    def get_many_to_one_rel(cls, info, meta_fields):
        """Return ``__`` lookups in *meta_fields* whose first segment is a ForeignKey.

        Only lookups that continue past the ForeignKey (e.g. ``'author__books'``)
        are returned; bare relation names are left to ``get_forward_rel``.
        OneToOneField is included because it subclasses ForeignKey.
        """
        foreign_keys = {
            field_name
            for field_name, relation_info in info.forward_relations.items()
            if isinstance(relation_info.model_field, models.ForeignKey)
        }
        if not foreign_keys:
            return []
        return [
            lookup
            for lookup_name, lookup in cls.get_lookups(meta_fields, strict=True)
            if lookup_name in foreign_keys
        ]

    @classmethod
    def get_forward_rel(cls, info, meta_fields):
        """Return the names in *meta_fields* that are forward, single-valued relations."""
        return [
            field_name
            for field_name, relation_info in info.forward_relations.items()
            if field_name in meta_fields and not relation_info.to_many
        ]

    @classmethod
    def _get_meta_fields(cls, meta, info):
        """Return the field names selected by a ModelSerializer ``Meta``.

        An explicit ``Meta.fields`` sequence is returned as a list. For
        ``'__all__'``, or when only ``Meta.exclude`` is set, the default set is
        built like DRF's ``ModelSerializer.get_default_field_names`` (primary
        key, concrete fields and forward relations), minus anything excluded.
        Returns ``[]`` when neither option is declared.
        """
        fields = getattr(meta, 'fields', None)
        exclude = getattr(meta, 'exclude', None)
        if fields is not None and fields != serializers.ALL_FIELDS:
            return list(fields)
        if fields is None and exclude is None:
            return []
        default_fields = [info.pk.name]
        default_fields.extend(info.fields)
        default_fields.extend(info.forward_relations)
        excluded = set(exclude or ())
        return [field_name for field_name in default_fields if field_name not in excluded]

    def __new__(cls, name, bases, attrs):
        serializer_class = attrs.get('serializer_class', None)

        # Merge `_base_forward_rel` from the class body and from every base.
        base_forward_rel = list(attrs.pop('_base_forward_rel', ()))
        for base in reversed(bases):
            base_forward_rel.extend(getattr(base, '_base_forward_rel', ()))
        # Store the merged declaration back on the class; popping it alone
        # would hide it from subclasses created with this metaclass.
        attrs['_base_forward_rel'] = tuple(base_forward_rel)

        # Three independent lists: a chained `a = b = []` would share one list.
        many_to_many_fields = []
        many_to_one_fields = []
        info = None

        if isinstance(serializer_class, type) and issubclass(
            serializer_class, serializers.ModelSerializer
        ):
            base_forward_rel.extend(getattr(serializer_class, '_related_fields', ()))
            many_to_many_fields.extend(getattr(serializer_class, '_many_to_many_fields', ()))
            many_to_one_fields.extend(getattr(serializer_class, '_many_to_one_fields', ()))

            meta = getattr(serializer_class, 'Meta', None)
            model = getattr(meta, 'model', None)
            if model is not None:
                info = model_meta.get_field_info(model)
                meta_fields = cls._get_meta_fields(meta, info)
                many_to_many_fields.extend(meta_fields)
                many_to_one_fields.extend(meta_fields)
                base_forward_rel.extend(meta_fields)

        # Compare with None instead of relying on truthiness: bool() on a
        # QuerySet evaluates it, i.e. queries the database at import time, and
        # an empty result would wrongly discard the declared queryset.
        queryset = attrs.get('queryset')
        if queryset is not None:
            if info is not None:
                prefetch = set(cls.get_many_to_many_rel(info, set(many_to_many_fields)))
                prefetch.update(cls.get_many_to_one_rel(info, set(many_to_one_fields)))
                related_fields = cls.get_forward_rel(info, set(base_forward_rel))
                # Prefetch whenever either kind of lookup exists, not only
                # when there are many-to-many lookups.
                if prefetch:
                    queryset = queryset.prefetch_related(
                        *normalize_prefetch_lookups(sorted(prefetch))
                    )
                if related_fields:
                    queryset = queryset.select_related(*related_fields)
            attrs['queryset'] = queryset.all()

        return super(ModelViewSetMetaclass, cls).__new__(cls, name, bases, attrs)


# -------------------------------------------------
# Usage (example only: `MyModel` and `MySerializer` stand in for your own
# model and ModelSerializer subclass)
# -------------------------------------------------
#
#     from rest_framework import viewsets
#
#     class MyModelViewSet(viewsets.ModelViewSet, metaclass=ModelViewSetMetaclass):
#         """
#         API Endpoint for MyModel which should optimize the queryset based on
#         the fields declared on the serializer.
#         """
#         queryset = MyModel.objects.all()
#         serializer_class = MySerializer  # Used to determine which fields should be prefetched
#
# On Python 2, apply the metaclass with the standalone `six` package instead
# (`django.utils.six` was removed in Django 3.0):
#
#     import six
#
#     @six.add_metaclass(ModelViewSetMetaclass)
#     class MyModelViewSet(viewsets.ModelViewSet):
#         ...