from django import template
from django.template import Library, Node, Variable, loader
from django.template.context import Context

# Module-level tag library; Django's template engine looks for a variable
# named ``register`` holding a Library instance in tag modules.
register = Library()

def trim(s):
    """Return ``s`` with leading and trailing whitespace removed."""
    stripped = s.strip()
    return stripped

def stripquotes(s):
    """Remove one pair of matching surrounding quotes from ``s``, if present.

    ``'"foo"'`` and ``"'foo'"`` both become ``'foo'``. Anything else is
    returned unchanged: an empty string, a lone quote character, or
    mismatched/absent quotes.
    """
    # The length guard matters for two inputs. The empty string used to
    # raise IndexError on s[0]. A single quote character matched itself
    # (s[0] == s[-1]) and was collapsed to ''.
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
        return s[1:-1]
    return s

class PartialTemplateNode(Node):
    """Render another template with an isolated, explicitly passed context.

    ``context_vars`` maps names that will be visible inside the included
    template to variable expressions. Those expressions are resolved against
    the *calling* template's context at render time. Only these names (plus
    the caller's autoescape setting) reach the included template.
    """

    def __init__(self, template_name, context_vars):
        self.template_name = template_name
        # Compile each expression into a Variable once, at parse time,
        # instead of on every render.
        self.context_vars = {
            name: Variable(expr) for name, expr in context_vars.items()
        }

    def render(self, context):
        """Resolve the configured variables in ``context`` and render the partial."""
        partial = loader.get_template(self.template_name)
        values = {
            name: var.resolve(context)
            for name, var in self.context_vars.items()
        }
        # On Django >= 1.8, loader.get_template() returns a backend wrapper.
        # Newer versions of that wrapper refuse a Context in render(). Its
        # ``template`` attribute is the engine-level Template, which still
        # takes a Context. Older Django versions return that Template
        # directly, so fall back to the object itself.
        partial = getattr(partial, 'template', partial)
        # Context() defaults to autoescape=True. Carry over the caller's
        # setting so {% autoescape off %} (or an engine configured without
        # autoescaping) also applies inside the partial.
        return partial.render(Context(values, autoescape=context.autoescape))

@register.tag(name='partial')
def partial_template(parser, token):
    """Compile the ``{% partial %}`` tag into a PartialTemplateNode.

    Supported forms::

        {% partial "name.html" %}
        {% partial "name.html" obj %}          -> item=obj
        {% partial "name.html" a=x, b=y %}

    A bare value (no ``=``) is bound to the name ``item``. Surrounding quotes
    are stripped from names and values. PartialTemplateNode then treats each
    value as a variable expression.

    Raises:
        TemplateSyntaxError: if the template name is missing, or an argument
            has an empty name or value (e.g. ``a=``).
    """
    bits = token.split_contents()
    if len(bits) < 2:
        raise template.TemplateSyntaxError(
            "%r tag requires a template name" % bits[0])
    template_name = trim(stripquotes(bits[1]))
    # split_contents() splits on whitespace, so "a = x, b=y" arrives as
    # several bits. Gluing them back together lets the comma/equals parsing
    # below work regardless of spacing.
    rest = ''.join(bits[2:])
    context_vars = {}
    for pair in rest.split(','):
        # Skip empty segments: no arguments at all (''.split(',') == ['']),
        # or a trailing comma.
        if not trim(pair):
            continue
        pieces = [trim(piece) for piece in pair.split('=')]
        if not all(pieces):
            raise template.TemplateSyntaxError(
                "%r tag got a malformed argument: %r" % (bits[0], pair))
        # Build a real list (not map()) so len() also works on Python 3.
        parts = [stripquotes(piece) for piece in pieces]
        if len(parts) == 1:
            context_vars['item'] = parts[0]
        else:
            context_vars[parts[0]] = parts[1]
    return PartialTemplateNode(template_name, context_vars)