from django.conf import settings
from django.db.models import Model, ForeignKey, ManyToManyField, OneToOneField

def get_all_models():
    """Collect the concrete Django model classes exposed by installed apps.

    For each entry in ``settings.INSTALLED_APPS`` the app's ``models``
    submodule is imported; apps without one are skipped.  Every module-level
    name that is a concrete subclass of ``Model`` is collected.  This
    includes models that inherit from an abstract base or from another
    model, not only direct subclasses of ``Model``.  Abstract models are
    skipped: they have no table, cannot be the target of a relation, and may
    carry unresolved lazy relation targets.  Model classes that an app's
    ``models`` module merely imports from elsewhere are picked up too; the
    result is a set, so duplicates collapse.

    Returns:
        set: the model classes found.
    """
    import importlib

    all_models = set()
    for app in settings.INSTALLED_APPS:
        try:
            models = importlib.import_module(app + '.models')
        except ImportError:
            # Apps without a models module are skipped.
            # NOTE(review): this also hides ImportErrors raised *inside* an
            # existing models.py.
            continue
        for attr in dir(models):
            obj = getattr(models, attr)
            # Short-circuit order matters: ``Model`` itself has no ``_meta``.
            if (isinstance(obj, type)
                    and issubclass(obj, Model)
                    and obj is not Model
                    and not obj._meta.abstract):
                all_models.add(obj)
    return all_models

def _related_model(field):
    """Return the model class a relation field points at (old and new Django APIs)."""
    remote = getattr(field, 'remote_field', None)
    if remote is not None:
        return remote.model
    return field.rel.to


def get_relns(all_models):
    """Find the relation fields between the given models.

    Both ``model._meta.fields`` and ``model._meta.many_to_many`` are scanned.
    Django keeps many-to-many fields out of ``_meta.fields``, so scanning
    only that list would never find an M2M relation.  Each relation found is
    printed and recorded as a ``(source_model, target_model)`` pair.

    Args:
        all_models: collection of model classes; every relation target must
            also be a member.

    Returns:
        tuple: ``(foreignkeys, one_to_one, many_to_many)`` lists of pairs.

    Raises:
        ValueError: if a relation targets a model not in ``all_models``.
    """
    foreignkeys = []
    one_to_one = []
    many_to_many = []

    # OneToOneField subclasses ForeignKey, so it must be tested first.
    kinds = (
        (OneToOneField, 'o2o', one_to_one),
        (ForeignKey, 'fk', foreignkeys),
        (ManyToManyField, 'm2m', many_to_many),
    )

    for model in all_models:
        fields = list(model._meta.fields) + list(model._meta.many_to_many)
        for field in fields:
            for field_cls, tag, bucket in kinds:
                if isinstance(field, field_cls):
                    break
            else:
                continue  # not a relation field we draw
            to = _related_model(field)
            if to not in all_models:
                raise ValueError('%r has a %s relation to %r, which is not in all_models'
                                 % (model, tag, to))
            print("%s %s to %s" % (repr(model), tag, repr(to)))
            bucket.append((model, to))
    return foreignkeys, one_to_one, many_to_many
                   
def class_name(cls):
    """Return the quoted, fully qualified dotted name of ``cls`` for use in DOT.

    The name is built from ``__module__`` and ``__qualname__`` (falling back
    to ``__name__`` on Python 2).  This replaces slicing ``repr(cls)`` between
    single quotes, which raises ``ValueError`` for any class whose repr
    contains no quotes, e.g. a metaclass with a custom ``__repr__``.
    """
    name = getattr(cls, '__qualname__', cls.__name__)
    return quoted('%s.%s' % (cls.__module__, name))

# Much of the code below was adapted from code written for a paper and
# should probably be cleaned up.

import os, collections

# Whole-graph colours: light foreground on a dark background.
BG_COLOR = 'black'
FG_COLOR = 'white'

# Colour values for DOT attributes.  Each value carries its own double
# quotes, so it can be inserted into an attribute list verbatim.
RED = '"#B30000"'
GREEN = '"#008F00"'
PURPLE = '"#24006B"'
YELLOW = '"#B38F00"'
GREY = '"#8F8F8F"'
COLORS = [YELLOW, GREEN, PURPLE, RED]

def quoted(s):
    """Return ``str(s)`` as a double-quoted DOT string.

    Embedded double quotes are backslash-escaped.  In DOT, ``\\"`` is the
    only escape the lexer processes inside a quoted string; an unescaped
    ``"`` in ``s`` would terminate the string early and corrupt the output.
    """
    return '"%s"' % str(s).replace('"', '\\"')
    
def nodename(state):
    """Return the quoted DOT node name for ``state[1]``.

    An empty-string name is replaced by ``start``.
    NOTE(review): not called anywhere in this file; presumably left over
    from the paper code that ``state`` tuples came from.
    """
    label = state[1]
    if label == '':
        label = 'start'
    return quoted(label)

def attrs_string(attrs):
    """Render a mapping as a DOT attribute list, e.g. ``[color=red, len=3]``.

    Keys and values are inserted verbatim via ``%s``, so values that need
    quoting must already be quoted by the caller.  Attributes are emitted in
    sorted key order so the generated DOT is stable from run to run.  Uses
    ``items()`` rather than the Python 2-only ``iteritems()``.
    """
    pairs = sorted(attrs.items())
    return '[%s]' % ', '.join('%s=%s' % (k, v) for k, v in pairs)
    
def model_attrs_string(model, extra_node_attrs, labels):
    """Return the DOT attribute list for ``model``'s node.

    The node is coloured GREEN when ``'django'`` occurs anywhere in
    ``repr(model)``, and PURPLE otherwise.  When ``labels`` is false the
    node's label is blanked.  Attributes in ``extra_node_attrs[model]``, if
    present, override these defaults.

    NOTE(review): the substring test is presumably meant to pick out
    Django's own models, but it also matches any app path containing
    'django'.
    """
    from_django = 'django' in repr(model)
    node_attrs = {'color': GREEN if from_django else PURPLE}
    if not labels:
        node_attrs['label'] = '""'
    overrides = extra_node_attrs.get(model, {})
    node_attrs.update(overrides)
    return attrs_string(node_attrs)

def edge_attrs_string(c):
    """Return a DOT attribute list that labels an edge with ``c``."""
    label = quoted(c)
    return attrs_string(dict(label=label))

def fg_attrs_string(extra=None):
    """Return a DOT attribute list drawn in the foreground colour.

    Both ``color`` and ``fontcolor`` are set to ``FG_COLOR``.  Entries in
    ``extra`` (optional) are merged on top and win on conflicts; ``extra``
    itself is never modified.  ``None`` is the default, instead of a
    shared mutable ``{}``.
    """
    attrs = {
        'color': FG_COLOR,
        'fontcolor': FG_COLOR,
    }
    if extra is not None:
        attrs.update(extra)
    return attrs_string(attrs)

def relns_to_dot(models, foreignkeys, one_to_one, many_to_many, extra_node_attrs=None, **kwargs):
    """Render models and their relations as a Graphviz DOT digraph.

    Args:
        models: model classes to draw as nodes.
        foreignkeys, one_to_one, many_to_many: lists of ``(source, dest)``
            model pairs, drawn as edges labelled ``fk``, ``o2o`` and ``m2m``.
        extra_node_attrs: optional mapping of model -> dict of DOT attributes
            overriding that node's defaults.
        edgelen (keyword): default edge ``len`` attribute (default 3).
        labels (keyword): when false, nodes are drawn without text labels
            (default True).

    Other keyword arguments are ignored.

    Returns:
        str: the DOT source, one statement per line.
    """
    if extra_node_attrs is None:
        extra_node_attrs = {}
    edge_len = kwargs.pop('edgelen', 3)
    labels = kwargs.pop('labels', True)

    result = [
        'digraph django_model_relationships {',
        '    graph [bgcolor=%s];' % BG_COLOR,
        '    node %s;' % fg_attrs_string(),
        '    edge %s;' % fg_attrs_string({'len': edge_len}),
        '    rankdir=LR;',
        '    size="30,20!";',
    ]
    # Classes are not orderable on Python 3, so sort by dotted name.  This
    # also makes node order deterministic across runs.
    for model in sorted(models, key=lambda m: (m.__module__, m.__name__)):
        result.append('    %s %s;' % (class_name(model), model_attrs_string(model, extra_node_attrs, labels)))
    edge_groups = (('fk', foreignkeys), ('o2o', one_to_one), ('m2m', many_to_many))
    for edge_label, pairs in edge_groups:
        for source, dest in pairs:
            result.append('    %s -> %s %s;' % (class_name(source), class_name(dest), edge_attrs_string(edge_label)))
    result.append('}')
    return '\n'.join(result)

def write_dot(dot, filename, neato_options=None, *args, **kwargs):
    """Write DOT source to ``filename`` and render it to ``filename + '.png'``.

    Args:
        dot: DOT source text; written to disk UTF-8 encoded.
        filename: path of the ``.dot`` file to write.
        neato_options: optional mapping of graph attributes, each passed to
            the renderer as ``-G<key>=<value>``.
        *args: accepted for backward compatibility; unused.
        program (keyword): Graphviz program to run (default ``'dot'``).

    Returns:
        int: the renderer's exit status.

    Raises:
        OSError: if ``program`` cannot be executed.
    """
    import subprocess

    if neato_options is None:
        neato_options = {}
    program = kwargs.pop('program', 'dot')
    # Binary mode: we write encoded bytes, which works on Python 2 and 3.
    with open(filename, 'wb') as f:
        f.write(dot.encode('utf-8'))
    # An argument list instead of a shell string keeps paths and option
    # values with spaces or shell metacharacters intact.
    cmd = [program, filename]
    cmd.extend('-G%s=%s' % pair for pair in neato_options.items())
    cmd.extend(['-T', 'png', '-o', '%s.png' % filename])
    return subprocess.call(cmd)

def main():
    """Graph all installed models' relations into /tmp/model_graph.dot (and .png)."""
    out_path = '/tmp/model_graph.dot'
    all_models = get_all_models()
    fks, o2os, m2ms = get_relns(all_models)
    graph_source = relns_to_dot(all_models, fks, o2os, m2ms)
    write_dot(graph_source, out_path)

if __name__ == "__main__":
    # Run only when executed as a script, never on import.
    main()