import re

from django import forms
from django.forms.fields import Field, EMPTY_VALUES
from django.forms.util import smart_unicode

# Matches a whitespace-free string consisting of exactly 11 digit characters;
# the whole number is captured as group 1. Callers strip whitespace first.
ABN_DIGITS_RE = re.compile(r'^(\d{11})$')

class ABNField(Field):
    """Form field that validates an Australian Business Number (ABN).

    Whitespace in the submitted value is ignored, so grouped input such as
    ``u'51 824 753 556'`` is accepted. The cleaned value is the 11-digit ABN
    as a unicode string, or ``u''`` when the input is empty.
    """
    default_error_messages = {
        'invalid': u'Australian Business Numbers must contain 11 digits.',
    }
    # Per-position weighting factors used by the ABN checksum.
    weights = [10, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]

    def clean(self, value):
        """Validate ``value`` and return the normalised ABN.

        Returns ``u''`` for empty input; otherwise the 11-digit ABN with all
        whitespace removed.

        Raises ``forms.ValidationError`` if the whitespace-stripped value is
        not exactly 11 digits or fails the ABN checksum.
        """
        super(ABNField, self).clean(value)
        if value in EMPTY_VALUES:
            return u''

        # ABNs are commonly written in space-separated groups; drop all
        # whitespace before matching the bare digit string.
        value = re.sub(r'\s+', '', smart_unicode(value))
        abn_match = ABN_DIGITS_RE.search(value)
        if abn_match is None:
            raise forms.ValidationError(self.error_messages['invalid'])
        return self.validate_abn(abn_match.group(1))

    def validate_abn(self, value):
        """Check ``value`` against the ABN checksum and return it unchanged.

        Algorithm: subtract 1 from the first digit, multiply each digit by
        its weighting factor, and sum the products. The ABN is valid when
        that sum is divisible by 89.

        Raises ``forms.ValidationError`` if ``value`` does not have one digit
        per weight or fails the checksum.
        """
        # http://www.ato.gov.au/businesses/content.asp?doc=/content/13187.htm&pc=001/003/021/002/001&mnu=610&mfp=001/003&st=&cy=1
        digits = [int(d) for d in value]

        # Guard direct callers: without this, an empty value raises IndexError
        # and a short value is silently checksummed over too few digits.
        if len(digits) != len(self.weights):
            raise forms.ValidationError(self.error_messages['invalid'])

        digits[0] -= 1
        total = sum(digit * weight
                    for digit, weight in zip(digits, self.weights))

        if total % 89 != 0:
            raise forms.ValidationError(self.error_messages['invalid'])

        return value


# Matches a whitespace-free string consisting of exactly 9 digit characters;
# the whole number is captured as group 1. Callers strip whitespace first.
ACN_DIGITS_RE = re.compile(r'^(\d{9})$')


class ACNField(Field):
    """Form field that validates an Australian Company Number (ACN).

    Whitespace in the submitted value is ignored, so grouped input such as
    ``u'000 250 000'`` is accepted. The cleaned value is the 9-digit ACN as
    a unicode string, or ``u''`` when the input is empty.
    """
    default_error_messages = {
        'invalid': u'Australian Company Numbers must contain 9 digits.',
    }
    # Weights for the first eight digits; the ninth digit is the check digit.
    weights = [8, 7, 6, 5, 4, 3, 2, 1]

    def clean(self, value):
        """Validate ``value`` and return the normalised ACN.

        Returns ``u''`` for empty input; otherwise the 9-digit ACN with all
        whitespace removed.

        Raises ``forms.ValidationError`` if the whitespace-stripped value is
        not exactly 9 digits or its check digit is wrong.
        """
        super(ACNField, self).clean(value)

        if value in EMPTY_VALUES:
            return u''

        # ACNs are commonly written in space-separated groups; drop all
        # whitespace before matching the bare digit string.
        value = re.sub(r'\s+', '', smart_unicode(value))
        acn_match = ACN_DIGITS_RE.search(value)
        if acn_match is None:
            raise forms.ValidationError(self.error_messages['invalid'])
        return self.validate_acn(acn_match.group(1))

    def validate_acn(self, value):
        """Verify the ACN check digit of ``value`` and return it unchanged.

        Modified modulus 10: multiply the first eight digits by their weights,
        sum the products, take the remainder modulo 10, and complement it to
        10. A complement of 10 becomes 0. The result must equal the ninth
        (check) digit.

        Raises ``forms.ValidationError`` if ``value`` is not nine digits long
        or the check digit does not match.
        """
        # http://www.asic.gov.au/asic/asic.nsf/byheadline/Australian+Company+Number+(ACN)+Check+Digit
        digits = [int(d) for d in value]

        # Guard direct callers: eight weighted digits plus one check digit.
        if len(digits) != len(self.weights) + 1:
            raise forms.ValidationError(self.error_messages['invalid'])

        check_digit = digits.pop()
        total = sum(digit * weight
                    for digit, weight in zip(digits, self.weights))

        # The outer "% 10" maps a complement of 10 to 0, so ACNs whose check
        # digit is 0 (e.g. 000 250 000) validate correctly.
        expected = (10 - total % 10) % 10

        if check_digit != expected:
            raise forms.ValidationError(self.error_messages['invalid'])

        return value