from datetime import timedelta
from functools import wraps
from hashlib import md5
from importlib import import_module
from random import randint

from django.core.cache import cache
from redis import Redis
from rq import Queue


# Maximum fractional jitter applied to flag expiry in store_cache_result
# (0.2 == +/- 20% of the requested cache duration).
EXPIRY_VARIANCE = 0.2


# Module-level Redis connection and the RQ queue onto which stale-result
# refreshes (refresh_cache jobs) are enqueued by cache_result.
redis_conn = Redis()
redis_queue = Queue(connection=redis_conn)


class CacheWarmingUp(Exception):
    """Signal that a result is being generated elsewhere but is not cached yet.

    Raised by ``cache_result`` wrappers when the freshness flag is present but
    no stored result exists. The single argument is the unhashed cache key
    that was requested.
    """


# Based on http://djangosnippets.org/snippets/564/
# Modified to stagger cache results (+/- 20% to cache durations)
# Modified to prevent cache stampeding (on cache miss, inserts a placeholder while regenerating a hit)
# Modified to allow correct storage of None results without triggering cache expiry
# Modified to provide asynchronous cache regeneration through python-rq
# Objects passed as arguments should customize their __repr__ to include updated
# timestamps and primary keys, since the repr of the arguments forms the cache key.
def cache_result(*decorator_args, **decorator_kwargs):
    """Decorate a function so its result is cached for a ``timedelta`` duration.

    Positional and keyword arguments are forwarded to ``datetime.timedelta`` to
    build the cache duration, except for these decorator-only keywords:

    * ``cache_key_override`` -- use this literal cache key instead of one
      derived from the function's module, name and call arguments.
    * ``cache_prevent_async`` -- when true, stale results are regenerated
      synchronously instead of via the RQ queue. Defaults to ``True`` for
      durations under 600 seconds.

    The decorated function additionally accepts two per-call keywords, which
    are consumed and never reach the wrapped function:

    * ``skip_cache_read=True`` -- ignore anything cached and regenerate now.
    * ``cache_key_override=...`` -- use this literal cache key for this call.

    A duration of 0 seconds returns the function undecorated.

    Raises (from the decorated function):
        CacheWarmingUp: the freshness flag is set but no result is stored yet,
            i.e. another caller is currently generating it.
    """
    # Decorator-only options must be removed *before* building the timedelta,
    # which rejects unknown keyword arguments.
    decorator_cache_key = decorator_kwargs.pop("cache_key_override", None)
    prevent_asynchronous_deferral = decorator_kwargs.pop("cache_prevent_async", None)
    duration = timedelta(*decorator_args, **decorator_kwargs)
    seconds = int(duration.total_seconds())
    if prevent_asynchronous_deferral is None:
        # Short-lived entries are regenerated inline rather than deferred.
        prevent_asynchronous_deferral = seconds < 600

    def doCache(original_function):

        # If caching for 0 seconds, don't even bother
        if not seconds:
            return original_function

        @wraps(original_function)
        def wrapped_function(*function_args, **function_kwargs):

            # Per-call control keywords are consumed here.
            skip_cache_read = function_kwargs.pop("skip_cache_read", False)
            function_cache_key = function_kwargs.pop("cache_key_override", None)

            # Generate the key from the function name and given arguments,
            # unless an explicit override key was supplied.
            key = function_cache_key or decorator_cache_key
            if key:
                unhashed_key = key
            else:
                unhashed_key = unicode([
                    original_function.__module__,
                    original_function.__name__,
                    function_args,
                    function_kwargs,
                ])
                # Encode explicitly: hashing a unicode object implicitly uses
                # ASCII and raises on non-ASCII reprs. ASCII keys hash the same.
                key = md5(unhashed_key.encode("utf-8")).hexdigest()
            flag = key + "flag"

            # Determine what's already stored in cache
            cached_flag = cache.get(flag) if not skip_cache_read else None
            cached_result = cache.get(key) if not skip_cache_read else None

            if cached_flag:

                # If a result exists, return it. Results are stored as 1-tuples,
                # so a cached None is still truthy here.
                if cached_result:
                    return cached_result[0]

                # If the flag is set but we have no result then the cache is warming up
                else:
                    raise CacheWarmingUp(unhashed_key)

            else:
                # Protect against cache stampeding: a short-lived placeholder flag
                # makes concurrent callers see a "fresh" entry while we regenerate.
                cache.set(flag, True, 10 + pow(seconds, 0.5))

                # Enqueue a task to regenerate the result, returning a stale result in the interim
                if cached_result and not prevent_asynchronous_deferral:
                    redis_queue.enqueue(
                        refresh_cache,
                        key,
                        flag,
                        seconds,
                        original_function.__module__,
                        original_function.__name__,
                        function_args,
                        function_kwargs,
                    )
                    return cached_result[0]

                # No usable result (or async deferral disabled): generate it right now
                else:
                    result = original_function(*function_args, **function_kwargs)
                    store_cache_result(key, flag, seconds, result)
                    return result

        return wrapped_function

    return doCache


# Shared helper that writes a result together with its freshness flag.
def store_cache_result(key, flag, seconds, result):
    """Store ``result`` under ``key`` and mark it fresh via ``flag``.

    The flag expires after ``seconds`` shifted by a random amount of up to
    ``EXPIRY_VARIANCE`` (as a fraction of ``seconds``) in either direction,
    so entries cached together do not all expire together. The result itself
    is kept for ``seconds * 100`` so a stale copy remains available while it
    is being regenerated. It is wrapped in a 1-tuple so that a cached ``None``
    is distinguishable from a cache miss.
    """
    jitter_limit = int(seconds * EXPIRY_VARIANCE)
    flag_timeout = seconds + randint(-jitter_limit, jitter_limit)
    cache.set(flag, True, flag_timeout)
    # Payload outlives the flag by far, so stale data can be served meanwhile.
    cache.set(key, (result,), seconds * 100)


# The function by which our job queue will process deferred calls
def refresh_cache(key, flag, seconds, function_module, function_name, function_args, function_kwargs):
    """Regenerate a cached result; run as a deferred job on the RQ queue.

    Re-resolves the decorated callable and calls it with
    ``skip_cache_read=True`` and ``cache_key_override=key``. The decorator
    then bypasses the stale entry and writes a fresh result to the same key.

    :param key: cache key the stale result was read from.
    :param flag: flag key paired with ``key`` (not used here; part of the
        enqueued job's signature).
    :param seconds: cache duration (not used here; part of the enqueued
        job's signature).
    :param function_module: ``__module__`` of the undecorated function.
    :param function_name: name of the undecorated function.
    :param function_args: positional args of the original call; when the
        function is not found on its module, the first one is treated as the
        instance owning the method.
    :param function_kwargs: keyword args of the original call (not mutated).
    """
    # Copy so the passed-in kwargs dict is left untouched.
    function_kwargs = dict(function_kwargs)
    function_kwargs["skip_cache_read"] = True
    # Always write back to the exact key that was read. Recomputing it from the
    # arguments would ignore a per-call cache_key_override, and can drift if an
    # argument's repr differs once the worker has deserialized the job.
    function_kwargs["cache_key_override"] = key
    module = import_module(function_module)
    try:
        function = getattr(module, function_name)
    except AttributeError:
        # Not a module-level function: look the method up on the instance
        # passed as the first positional argument.
        function = getattr(function_args[0], function_name)
        function_args = function_args[1:]
    function(*function_args, **function_kwargs)