from django.newforms import *
from django.newforms.widgets import flatatt

def form_decorator(fields = None, attrs = None, widgets = None,
	labels = None, choices = None):

	"""
	Build a formfield callback that applies per-field overrides when creating
	forms from models/instances.

	Every argument is an optional dictionary keyed by model field name:

	fields  -- replacement form field used instead of the default one. A value
	           of None is returned as-is to the caller.
	attrs   -- dictionary of extra attributes merged into the widget's attrs.
	widgets -- widget to use instead of the default one.
	labels  -- label to use instead of the default one.
	choices -- choices for the field; a callable is called (with no arguments)
	           each time the callback runs and its result is used.

	Returns a function callback(f, **kw) suitable as formfield_callback.

	For example:

	class Project(models.Model):

			name = models.CharField(maxlength = 100)
			description = models.TextField()
			owner = models.ForeignKey(User)

	project_fields = dict(
			owner = None
	)

	project_widgets = dict(
			name = forms.TextInput({"size":40}),
			description = forms.Textarea({"rows":5, "cols":40}))

	project_labels = dict(
			name = "Enter your project name here"
	)

	callback = form_decorator(fields = project_fields,
			widgets = project_widgets, labels = project_labels)
	project_form = forms.form_for_model(Project, formfield_callback = callback)

	(Use keyword arguments: the second positional parameter is attrs, not
	widgets.)

	This saves having to redefine whole fields for example just to change a widget
	setting or label.
	"""

	import copy

	# Only substitute a fresh dict for None so that a dict passed in by the
	# caller is still consulted by reference at callback time.
	if fields is None: fields = {}
	if attrs is None: attrs = {}
	if widgets is None: widgets = {}
	if labels is None: labels = {}
	if choices is None: choices = {}

	def formfields_callback(f, **kw):
		if f.name in fields:
			# replace field altogether
			field = fields[f.name]
			if field is not None and "initial" in kw:
				# Carry the supplied initial value over to the replacement
				# field. Work on a copy so that the shared object stored in
				# `fields` is not modified between calls.
				field = copy.copy(field)
				field.initial = kw["initial"]
			return field
		if f.name in widgets:
			kw["widget"] = widgets[f.name]
		if f.name in attrs:
			if "widget" in kw:
				widget = kw.pop("widget")
			else:
				# Only build the default form field when no widget was
				# supplied; formfield() may return None.
				widget = getattr(f.formfield(), "widget", None)
			if widget is not None:
				widget.attrs.update(attrs[f.name])
				kw["widget"] = widget
		if f.name in labels:
			kw["label"] = labels[f.name]
		if f.name in choices:
			choice_set = choices[f.name]
			if callable(choice_set):
				choice_set = choice_set()
			kw["choices"] = choice_set
		return f.formfield(**kw)
	return formfields_callback

class ChainSelectWidget(Widget):
	"""
	Widget that renders a chain of dependent <select> boxes, driven by
	javascript, to narrow down ForeignKey objects in an intuitive manner.

	It is especially useful when the __str__ of the object behind the direct
	foreign key isn't necessarily unique, and the parent model of it needs
	to be looked at.

	This code uses the Chained Select javascript written by
	Xin Yang (http://www.yxscripts.com/); the page must load
	chainedselects.js itself.

	This widget must be used on custom views. The original author had a VERY
	hard time trying to get it registered into the form_for_model and
	form_for_instance helper functions.

	Example:

	### models.py ###
	class A(models.Model):
		name=models.CharField()
	class B(models.Model):
		name=models.CharField()
		to_A = models.ForeignKey(A)
	class C(models.Model):
		name=models.CharField()
		to_B = models.ForeignKey(B)

	### views.py ###
	def test(request):
		import A,B,C
		from CustomWidgets import *
		from django.newforms import form_for_model
		from django.shortcuts import render_to_response
		widget_overwrite=dict(to_B=ChainSelectWidget(order=[(A, 'name'), (B, 'name'), (C, 'name')]))
		callback=form_decorator(widgets=widget_overwrite)
		modified_form=form_for_model(C, formfield_callback=callback)()
		return render_to_response('path/to/template.html', {'form': modified_form})

	NOTE(review): the last model in `order` supplies the option values
	(primary keys) of the <select> that carries the field name, so it
	presumably should be the model the ForeignKey points at -- confirm that
	the `order` used in this example is what was intended.

	### template.html ###
	...
	<head>
	<script language="javascript" src="path/to/chainedselects.js"></script>
	</head>
	...
	<form>
	{% for field in form %}
	{{field.label}}: {{field}}
	{% endfor %}
	...
	"""

	def __init__(self, attrs=None, order=None):
		#Order is a list of model objects that define the chain select tree
		#it is a list of tuples.  The first value is the model object, the second
		#value the field to order by
		#eg:
		#	order=[(A, 'name'), (B, 'name'), (C, 'name')]
		self.attrs = attrs or {}
		# Buffer holding the markup produced by the most recent render().
		self.html = ''
		# Copy so that later changes to the caller's list do not affect us.
		self.order = list(order or [])

	def _js_escape(self, text):
		#Escape text for use inside a double-quoted javascript string literal
		#that is embedded in an inline <script> block.
		text = text.replace('\\', '\\\\').replace('"', '\\"')
		text = text.replace('\r', '\\r').replace('\n', '\\n')
		# Keep a literal "</script>" in the data from closing the script block.
		return text.replace('</', '<\\/')

	def _buildjs(self, current=None, backtrail=''):
		#Append the addOption/addList javascript calls for the objects in
		#`current` (and, recursively, their children) to self.html.
		#	current   -- queryset for one level of the chain; defaults to all
		#	             objects of the first model in self.order
		#	backtrail -- id of the javascript list the objects belong to
		#Returns '' when `current` is empty, otherwise None.
		#Raises ValueError if the model of `current` is not part of self.order.
		if current is None:
			root_model, root_order = self.order[0]
			current = root_model.objects.all().order_by(root_order)
		if len(current) == 0:
			return ''
		list_id = self._js_escape(backtrail)
		self.html += 'addOption("%s", "---------", "", 1);\n' % list_id
		level_name = current[0]._meta.module_name
		if level_name == self.order[-1][0]._meta.module_name:
			# Last level of the chain: the option values are the primary keys.
			for end in current:
				self.html += 'addList("%s", "%s", "%s");\n' % (
					list_id,
					self._js_escape(str(end)),
					self._js_escape(str(end._get_pk_val())))
			return
		# Look up the next level of the chain and order its objects by the
		# field configured for *that* level.
		level_names = [level[0]._meta.module_name for level in self.order]
		child_model, child_order = self.order[level_names.index(level_name) + 1]
		child_set = '%s_set' % child_model._meta.module_name
		for entry in current:
			label = str(entry)
			# Sub-list ids are derived from the labels, so siblings with
			# identical labels end up sharing one sub-list.
			sub_backtrail = '%s__%s' % (backtrail, label)
			self.html += 'addList("%s", "%s", "", "%s");\n' % (
				list_id,
				self._js_escape(label),
				self._js_escape(sub_backtrail))
			self._buildjs(
				backtrail=sub_backtrail,
				current=getattr(entry, child_set).all().order_by(child_order))

	def render(self, name, value, attrs=None):
		#Return the widget markup: a table with one <select> per model in
		#self.order followed by the javascript that fills and links them.
		#The <select> of the last model carries the field name and `attrs`;
		#the others are named chain_to_<name>__<model>.
		#`value` is not used, so no option is pre-selected from it.
		#The generated javascript addresses the selects via document.forms[0].
		#Raises IndexError if self.order is empty.

		# Start from an empty buffer so that rendering the same widget again
		# does not repeat the output of earlier calls.
		self.html = ''
		leaf_model = self.order[-1][0]
		# The model name is available on the class itself; no need to fetch
		# an object from the database (which fails on an empty table).
		root_name = self.order[0][0]._meta.module_name
		self.html += '<table>\n'
		for entry in self.order:
			model_name = entry[0]._meta.module_name
			self.html += '<tr><td align="right">%s:</td>\n' % model_name.capitalize()
			if entry[0] == leaf_model:
				final_attrs = self.build_attrs(attrs, name=name)
				self.html += '<td><select %s></select></td></tr>\n' % flatatt(final_attrs)
			else:
				self.html += '<td><select name="chain_to_%s__%s"></select></td></tr>\n' % (name, model_name)
		self.html += '</table>\n'
		self.html += '<script language="javascript">\n'
		self.html += 'var disable_empty_list=true;\n'
		self.html += 'var newwindow=0;\n'
		self.html += 'addListGroup("%s", "%s__%s");\n' % (name, name, root_name)
		self._buildjs(current=None, backtrail='%s__%s' % (name, root_name))
		self.html += '</script>\n'
		self.html += '<script language="javascript">\n'
		self.html += 'initListGroup("%s", ' % name
		for entry in self.order:
			if entry[0] == leaf_model:
				self.html += 'document.forms[0].%s, ' % name
			else:
				self.html += 'document.forms[0].chain_to_%s__%s, ' % (name, entry[0]._meta.module_name)
		self.html += '"savestate");\n'
		self.html += '</script>\n'
		return u'%s' % self.html