from django.db import models
from django.db.models import QuerySet
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query import normalize_prefetch_lookups
from rest_framework import serializers
from rest_framework.utils import model_meta


class ModelViewSetMetaclass(type):
    """
    Metaclass that optimizes a viewset's ``queryset`` at class-creation time
    using ``prefetch_related`` and ``select_related``.

    Sources of field names:

    * ``_base_forward_rel`` declared on the class being created, or found as
      an attribute on any of its bases: forward (non to-many) relations that
      are candidates for ``select_related``.
    * ``_related_fields`` on the ``serializer_class``: additional forward
      relation candidates for ``select_related``.
    * ``_many_to_many_fields`` and ``_many_to_one_fields`` on the
      ``serializer_class``: candidates for ``prefetch_related``. Both
      attributes feed a single candidate pool, so a name declared in either
      one is checked against every kind of prefetchable relation.
    * ``serializer_class.Meta.fields`` / ``Meta.exclude``: when the
      ``serializer_class`` is a ``serializers.ModelSerializer`` subclass with
      a ``Meta.model``, the selected field names are added to both candidate
      pools.

    Candidates are filtered against the model's relations (as reported by
    ``rest_framework.utils.model_meta.get_field_info``) before being applied,
    so names that are not relations are ignored. The queryset is only touched
    when the class body defines a ``queryset`` attribute.
    """

    @classmethod
    def get_many_to_many_rel(cls, info, meta_fields):
        """
        Return the lookups from ``meta_fields`` that start at a to-many relation.

        :param info: field info for the model; ``info.relations`` maps field
            names to relation descriptions exposing a ``to_many`` flag.
        :param meta_fields: iterable of candidate lookup strings.
        :return: list of full lookup strings whose first segment names a
            to-many relation.
        """
        to_many_names = {
            field_name for field_name, relation_info in info.relations.items()
            if relation_info.to_many
        }
        return [
            lookup for lookup_name, lookup in cls.get_lookups(meta_fields)
            if lookup_name in to_many_names
        ]

    @classmethod
    def get_lookups(cls, fields, strict=False):
        """
        Pair every lookup with its first path segment.

        :param fields: iterable of lookup strings (segments are separated by
            ``LOOKUP_SEP``).
        :param strict: when true, keep only lookups that actually contain
            ``LOOKUP_SEP`` (i.e. that span at least one relation).
        :return: list of ``(first_segment, full_lookup)`` tuples.
        """
        field_lookups = [
            (lookup.split(LOOKUP_SEP, 1)[0], lookup) for lookup in fields
        ]
        if strict:
            field_lookups = [
                pair for pair in field_lookups if LOOKUP_SEP in pair[1]
            ]
        return field_lookups

    @classmethod
    def get_many_to_one_rel(cls, info, meta_fields):
        """
        Return the nested lookups from ``meta_fields`` that start at a ForeignKey.

        Only lookups containing ``LOOKUP_SEP`` are considered, and only those
        whose first segment is a forward relation backed by a
        ``models.ForeignKey`` are returned.

        :param info: field info for the model; ``info.forward_relations`` maps
            field names to relation descriptions whose first item is tested
            against ``models.ForeignKey``.
        :param meta_fields: iterable of candidate lookup strings.
        :return: list of matching lookup strings (empty when nothing matches).
        """
        try:
            foreign_keys = {
                field_name
                for field_name, relation_info in info.forward_relations.items()
                if isinstance(relation_info[0], models.ForeignKey)
            }
        except IndexError:
            # Best effort, as before: a relation description without a first
            # item simply means there is nothing to prefetch here.
            return []

        return [
            lookup
            for lookup_name, lookup in cls.get_lookups(meta_fields, strict=True)
            if lookup_name in foreign_keys
        ]

    @classmethod
    def get_forward_rel(cls, info, meta_fields):
        """
        Return the forward, non to-many relation names present in ``meta_fields``.

        :param info: field info for the model (``info.forward_relations``).
        :param meta_fields: container of candidate field names.
        :return: list of field names suitable for ``select_related``.
        """
        return [
            field_name
            for field_name, relation_info in info.forward_relations.items()
            if field_name in meta_fields and not relation_info.to_many
        ]

    @classmethod
    def _get_meta_field_names(cls, meta, info):
        """
        Return the model field names selected by a serializer ``Meta``.

        ``Meta.fields`` wins over ``Meta.exclude``. The ``'__all__'`` marker
        and ``Meta.exclude`` are both resolved against the model's concrete
        fields plus its forward relations, so relations stay eligible for
        optimization instead of being lost.

        :param meta: the serializer's ``Meta`` class.
        :param info: field info for ``meta.model``.
        :return: list of field names (empty when ``Meta`` selects nothing).
        """
        default_names = list(info.fields) + list(info.forward_relations)

        declared = getattr(meta, 'fields', None)
        if declared is not None:
            if declared == serializers.ALL_FIELDS:
                # list('__all__') would yield single characters, not names.
                return default_names
            return list(declared)

        excluded = getattr(meta, 'exclude', None)
        if excluded is not None:
            return [fname for fname in default_names if fname not in excluded]

        return []

    def __new__(cls, name, bases, attrs):
        """
        Create the class, replacing ``attrs['queryset']`` (when present) with
        a queryset that has the relevant ``prefetch_related`` and
        ``select_related`` calls applied.
        """
        serializer_class = attrs.get('serializer_class', None)

        info = None
        meta_fields = []
        # Explicitly declared prefetch candidates (to-many and nested FK).
        declared_prefetch = []

        # NOTE: the attribute is popped, so classes built by this metaclass do
        # not keep `_base_forward_rel`; only bases that still expose the
        # attribute contribute through the loop below.
        base_forward_rel = list(attrs.pop('_base_forward_rel', ()))
        for base in reversed(bases):
            base_forward_rel.extend(getattr(base, '_base_forward_rel', ()))

        if isinstance(serializer_class, type) and issubclass(
                serializer_class, serializers.ModelSerializer):
            base_forward_rel.extend(
                getattr(serializer_class, '_related_fields', ()),
            )
            declared_prefetch.extend(
                getattr(serializer_class, '_many_to_many_fields', ()),
            )
            declared_prefetch.extend(
                getattr(serializer_class, '_many_to_one_fields', ()),
            )

            # A ModelSerializer subclass may lack `Meta` or `Meta.model`.
            meta = getattr(serializer_class, 'Meta', None)
            model = getattr(meta, 'model', None)
            if model is not None:
                info = model_meta.get_field_info(model)
                meta_fields = cls._get_meta_field_names(meta, info)
                base_forward_rel.extend(meta_fields)

        if info is not None:
            prefetch_candidates = set(declared_prefetch) | set(meta_fields)
            prefetch_lookups = set(
                cls.get_many_to_many_rel(info, prefetch_candidates)
            )
            prefetch_lookups.update(
                cls.get_many_to_one_rel(info, prefetch_candidates)
            )
            related_fields = cls.get_forward_rel(info, set(base_forward_rel))
        else:
            # Without model info nothing can be validated: prefetch only what
            # was declared explicitly and never guess `select_related` names.
            prefetch_lookups = set(declared_prefetch)
            related_fields = []

        if 'queryset' in attrs:
            queryset = attrs['queryset']
            # Compare with None instead of truth-testing: bool(queryset) would
            # run the query at import time and discard an empty result set.
            if queryset is None:
                queryset = QuerySet()
            if prefetch_lookups:
                # Sorted for a deterministic order (parents before children).
                queryset = queryset.prefetch_related(
                    *normalize_prefetch_lookups(sorted(prefetch_lookups))
                )
            if related_fields:
                queryset = queryset.select_related(*related_fields)
            attrs['queryset'] = queryset.all()

        return super(ModelViewSetMetaclass, cls).__new__(cls, name, bases, attrs)


#-------------------------------------------------
# Usage 
#-------------------------------------------------
from django.utils import six
from rest_framework import viewsets


@six.add_metaclass(ModelViewSetMetaclass)
class MyModelViewSet(viewsets.ModelViewSet):
    """
    API endpoint for MyModel whose queryset is optimized based on the fields
    declared on the serializer.
    """
    # NOTE(review): MyModel and MySerializer are not defined or imported in
    # the visible part of this file; presumably they stand in for the
    # project's own model and serializer -- import them before use.
    queryset = MyModel.objects.all()
    serializer_class = MySerializer  # Used to determine which fields should be prefetched