from typing import Dict, Iterable, List

from django.db.models.base import Model
from django.db.models.fields import Field
from django.db.models.manager import Manager

# Type alias: a list of Django model `Field` objects.
ListOfField = List[Field]
# NOTE(review): bare expression with no runtime effect; presumably kept only
# so linters don't report the `Dict` import as unused.
Dict


class HandoverObjects(object):
    """
    Hand over reverse `ForeignKey`/`OneToOneField` and `ManyToMany`
    relations from an `Iterable` of `Model` instances to another instance
    (identified by its PK) of the same `Model`.

    Only the auto-created, non-concrete relation fields reported by
    `Model._meta.get_fields()` are processed, i.e. relations that *other*
    models (or M2M through tables) hold towards these instances.

    The instances should come from the same `Model`!

    Caveats:
    - Foreign keys are assumed to target the primary key (no `to_field`),
      because the FK columns are compared against / set to PK values.
    - A reverse `OneToOneField` can only be handed over when at most one
      row ends up pointing at the new PK; otherwise the UPDATE violates
      the one-to-one column's unique constraint.
    - `process()` issues several independent queries; wrap it in
      `transaction.atomic()` if a partial handover must not persist when
      one of them fails.
    """

    def __init__(self, objs: Iterable[Model]):
        """
        :param objs: the instances whose relations should be handed over.
            Any iterable (including a one-shot generator) is accepted; it
            is consumed exactly once. An empty iterable makes `process()`
            a no-op.
        """
        self.objs = list(objs)
        # Iterate the materialised list, not `objs`: a generator would be
        # exhausted by now and yield no PKs at all.
        self.old_pks = [obj.pk for obj in self.objs]
        if not self.objs:
            # Nothing to hand over; with no fields `process()` does nothing.
            self.model = None
            self.fields = []
            return
        self.model = self.objs[0]._meta.model
        # Reverse relations only: auto-created, relational, non-concrete.
        self.fields = [
            field for field in self.model._meta.get_fields()
            if field.is_relation and field.auto_created
            and not field.concrete]

    def _pks_to_move(self):
        """
        Return `old_pks` without `self.new_pk`.

        `new_pk` may legitimately be among the old PKs (e.g. merging
        duplicates into the first one). Its rows must not be treated as
        duplicates of themselves, or M2M stage 1 would delete them.
        """
        return [pk for pk in self.old_pks if pk != self.new_pk]

    def _invalidate_caches_on_relation(self, field_name: str):
        """
        Drop the prefetch-cache entry `field_name` from every instance in
        `self.objs`, so later access re-queries instead of returning the
        pre-handover relation contents.
        """
        for obj in self.objs:
            cache = getattr(obj, '_prefetched_objects_cache', None)
            if cache is not None:
                cache.pop(field_name, None)

    def _invalidate_caches_on_reverse_field(self, field):
        """
        Invalidate the prefetch cache of the reverse relation `field`.

        Django keys this cache either by the relation's accessor name
        (e.g. `book_set`, reverse FK) or by its query name (e.g. `book`,
        reverse M2M), so both are dropped. Dropping an unrelated key only
        forces a harmless re-query.
        """
        for name in {field.name, field.get_accessor_name()}:
            if name:
                self._invalidate_caches_on_relation(field_name=name)

    def _process_one_to_x(self):
        """
        Repoint every reverse `ForeignKey`/`OneToOneField` relation from the
        old PKs to `self.new_pk`, with one UPDATE per relation.
        """
        fk_fields = [x for x in self.fields if x.one_to_many or x.one_to_one]
        for field in fk_fields:
            # `_base_manager` always exists (unlike `.objects`) and applies
            # no default-manager filtering, so rows hidden by a custom
            # manager are not left pointing at the old instances.
            manager = field.related_model._base_manager
            # Raw FK attribute on the related model, e.g. `author_id`.
            foreign_field_idname = field.field.attname
            manager.filter(
                **{f'{foreign_field_idname}__in': self._pks_to_move()}).update(
                **{foreign_field_idname: self.new_pk})
            self._invalidate_caches_on_reverse_field(field)

    def _m2m_remove_existing_dups(
            self, field_idname: str, other_field_idname: str,
            manager: Manager):
        """
        Stage 1: Remove the old PKs' through rows whose other side already
        has a relation with the new PK.
        """
        existing_other_ids = manager.filter(
            **{field_idname: self.new_pk}).values_list(
            other_field_idname, flat=True)
        # `existing_other_ids` is lazy; it is embedded in the DELETE below.
        manager.filter(**{
            f'{other_field_idname}__in': existing_other_ids,
            f'{field_idname}__in': self._pks_to_move()}).delete()

    def _m2m_remove_remaining_dups(
            self, field_idname: str, other_field_idname: str,
            manager: Manager):
        """
        Stage 2: After deleting the first duplicates, remove the ones that
        would cause a constraint error later. These are rows of different
        old PKs that share the same other side and would collapse onto the
        same (new PK, other) pair. The first row seen for each other side
        is kept.
        """
        to_update = manager.filter(
            **{f'{field_idname}__in': self._pks_to_move()})
        seen_other_ids = set()
        pks_to_remove = set()
        for row_pk, other_id in to_update.values_list(
                'pk', other_field_idname):
            if other_id in seen_other_ids:
                pks_to_remove.add(row_pk)
            else:
                seen_other_ids.add(other_id)
        if pks_to_remove:
            manager.filter(pk__in=pks_to_remove).delete()

    def _m2m_merge_field(self, field: Field):
        """
        Merge the reverse `ManyToMany` relation `field` onto `self.new_pk`.

        Through rows that would become duplicates are deleted first, then
        the remaining rows are repointed to the new PK.
        """
        m2m_field = field.field  # the ManyToManyField on the other model
        m2m_model = field.through
        manager = m2m_model._base_manager
        through_meta = m2m_model._meta
        # Ask the M2M field which through-table FK points at `self.model`
        # and which points at the declaring model. Guessing from model
        # names fails for self-referential relations and custom `through`
        # models.
        field_idname = through_meta.get_field(
            m2m_field.m2m_reverse_field_name()).attname
        other_field_idname = through_meta.get_field(
            m2m_field.m2m_field_name()).attname
        # Delete the old IDs where others already exist with new_pk
        self._m2m_remove_existing_dups(
            field_idname=field_idname, other_field_idname=other_field_idname,
            manager=manager)
        # Duplicates among the old PKs can only exist with >1 of them.
        if len(self._pks_to_move()) > 1:
            self._m2m_remove_remaining_dups(
                field_idname=field_idname,
                other_field_idname=other_field_idname, manager=manager)
        # Update any remaining to new_pk
        manager.filter(
            **{f'{field_idname}__in': self._pks_to_move()}).update(
            **{field_idname: self.new_pk})

    def _process_many_to_many(self):
        """Go through all reverse `ManyToMany` relations and merge them."""
        m2m_fields = [x for x in self.fields if x.many_to_many]
        for field in m2m_fields:
            self._m2m_merge_field(field=field)
            self._invalidate_caches_on_reverse_field(field)

    def process(self, new_pk: int):
        """
        Hand over all reverse `ForeignKey`/`OneToOneField` and `ManyToMany`
        relations of `self.objs` to the instance whose PK is `new_pk`.

        `new_pk` may be the PK of one of `self.objs`. That instance keeps
        its own relations and the others' are merged into it.

        @see
        https://docs.djangoproject.com/en/2.1/ref/models/meta/#migrating-from-the-old-api
        """
        self.new_pk = new_pk
        self._process_one_to_x()
        self._process_many_to_many()