from django.template.loader_tags import *

# Tag library for this module; the ``extends`` tag function defined further
# down is attached to it with ``@register.tag``.
register = Library()

class ExtendsNode(Node):
    """Template node produced by the ``extends`` tag.

    Holds the child template's node list together with either a literal
    parent template name (``parent_name``) or a compiled expression that is
    resolved against the context on each render (``parent_name_expr``).
    Rendering loads the parent template, substitutes the child's blocks into
    it and renders the result.
    """
    must_be_first = True

    def __init__(self, nodelist, parent_name, parent_name_expr, template_dirs=None):
        self.nodelist = nodelist
        self.parent_name, self.parent_name_expr = parent_name, parent_name_expr
        self.template_dirs = template_dirs
        # Processed parent template; only filled in by render() when the
        # parent was given as a literal name.
        self.compiled_parent = None

    def __repr__(self):
        if self.parent_name_expr:
            return "<ExtendsNode: extends %s>" % self.parent_name_expr.token
        return '<ExtendsNode: extends "%s">' % self.parent_name

    def get_parent(self, context):
        """Return the parent template to extend for this render.

        The parent is ``parent_name_expr`` resolved against ``context`` when
        an expression was given, otherwise the literal ``parent_name``.  A
        value that has a ``render`` attribute is returned as-is; anything
        else is treated as a template name and loaded from
        ``template_dirs``.

        Raises TemplateSyntaxError if the name is empty or the template
        source cannot be found.
        """
        if self.parent_name_expr:
            # Keep the resolved value local.  Writing it back to
            # self.parent_name made a dynamic parent indistinguishable from
            # a static one, which caused render() to cache it.
            parent = self.parent_name_expr.resolve(context)
        else:
            parent = self.parent_name
        if not parent:
            # Wrap in a 1-tuple so a falsy tuple value cannot break the
            # %-formatting of the message itself.
            error_msg = "Invalid template name in 'extends' tag: %r." % (parent,)
            if self.parent_name_expr:
                error_msg += " Got this from the '%s' variable." % self.parent_name_expr.token
            raise TemplateSyntaxError(error_msg)
        if hasattr(parent, 'render'):
            return parent  # parent is a Template object
        try:
            source, origin = find_template_source(parent, self.template_dirs)
        except TemplateDoesNotExist:
            raise TemplateSyntaxError(
                "Template %r cannot be extended, because it doesn't exist" % (parent,))
        return get_template_from_string(source, origin, parent)

    def process_parent(self, context):
        """Load the parent template and merge this template's blocks into it.

        Returns the parent template with its block nodes modified in place.
        """
        compiled_parent = self.get_parent(context)
        parent_nodes = compiled_parent.nodelist

        # Locate the first node that is not plain text.  A parent that is
        # empty or consists solely of text has no such node; indexing past
        # the end used to raise IndexError in that case.
        first_node = None
        for node in parent_nodes:
            if not isinstance(node, TextNode):
                first_node = node
                break
        parent_is_child = isinstance(first_node, ExtendsNode)

        parent_blocks = dict(
            (n.name, n) for n in parent_nodes.get_nodes_by_type(BlockNode))
        for block_node in self.nodelist.get_nodes_by_type(BlockNode):
            try:
                parent_block = parent_blocks[block_node.name]
            except KeyError:
                # The parent does not define this block.  If the parent
                # itself extends another template, hand the block on by
                # appending it to the parent's own ExtendsNode; otherwise
                # the block is left out.
                if parent_is_child:
                    first_node.nodelist.append(block_node)
            else:
                # The parent's block takes over the child's contents; its
                # previous contents are passed to add_parent() first
                # (presumably so they remain reachable as the "super"
                # block -- verify against BlockNode).
                parent_block.parent = block_node.parent
                parent_block.add_parent(parent_block.nodelist)
                parent_block.nodelist = block_node.nodelist
        return compiled_parent

    def render(self, context):
        """Render the parent template with this template's blocks applied."""
        if self.compiled_parent is not None:
            compiled_parent = self.compiled_parent
        else:
            compiled_parent = self.process_parent(context)
            # Cache only a static parent.  A parent named through a variable
            # may differ between renders and must be looked up every time.
            # NOTE(review): when the variable holds a Template object,
            # process_parent() modifies that shared object's blocks on each
            # render -- confirm this is acceptable for callers doing so.
            if not self.parent_name_expr:
                self.compiled_parent = compiled_parent

        return compiled_parent.render(context)


@register.tag
def extends(parser, token):
    """Compile an ``extends`` tag into an ExtendsNode.

    The tag takes exactly one argument.  An argument wrapped in matching
    single or double quotes is used as a literal parent template name;
    anything else is compiled as a filter expression and resolved when the
    node is rendered.  The rest of the template is parsed as the node's
    contents.

    Raises TemplateSyntaxError if the argument count is wrong or the parsed
    contents contain another ExtendsNode.
    """
    bits = token.contents.split()
    if len(bits) != 2:
        # Call form of raise: valid on Python 2 and 3, unlike "raise X, msg".
        raise TemplateSyntaxError("'%s' takes one argument" % bits[0])
    tag_name, argument = bits
    parent_name, parent_name_expr = None, None
    if argument[0] in ('"', "'") and argument[-1] == argument[0]:
        parent_name = argument[1:-1]
    else:
        parent_name_expr = parser.compile_filter(argument)
    nodelist = parser.parse()
    if nodelist.get_nodes_by_type(ExtendsNode):
        raise TemplateSyntaxError(
            "'%s' cannot appear more than once in the same template" % tag_name)
    return ExtendsNode(nodelist, parent_name, parent_name_expr)