import numbers

from django.db import models

# Exclusive upper bound on the number of categories a CategoriesField accepts.
# Each category occupies one bit of the stored integer, and the field rejects
# len(categories) >= MAX_CATEGORIES, so at most 63 categories are allowed --
# which leaves the sign bit of a signed 64-bit integer unused.
# NOTE(review): the field is declared as an IntegerField, which Django only
# guarantees for signed 32-bit values; on backends such as MySQL/PostgreSQL
# more than 31 categories may overflow -- confirm the backend in use.
MAX_CATEGORIES = 64

class CategoriesField(models.Field):
    """
    Model field storing a subset of a fixed list of categories as an integer
    bitmask.

    Bit ``i`` of the stored integer is set when ``categories[i]`` is selected,
    so the order of ``categories`` is part of the stored data: appending new
    categories is safe, but reordering or removing existing ones changes the
    meaning of rows that were already saved.

    NOTE(review): the column is declared as an IntegerField, which Django
    only guarantees for signed 32-bit values; with more than 31 categories
    the bitmask may overflow on some backends -- confirm before relying on
    the higher bits.
    """
    __metaclass__ = models.SubfieldBase

    def __init__(self, categories=None, *args, **kwds):
        """
        :param categories: ordered sequence of category values; position
            ``i`` maps to bit ``i`` of the stored integer.
        :raises ValueError: if ``len(categories) >= MAX_CATEGORIES``.
        """
        # Copy so that later mutation of the caller's sequence cannot
        # silently remap the bits of rows that are already stored.
        self.categories = list(categories or [])
        # An explicit raise (rather than ``assert``) keeps this check active
        # when Python runs with -O.
        if len(self.categories) >= MAX_CATEGORIES:
            raise ValueError("Too many categories!")
        super(CategoriesField, self).__init__(*args, **kwds)

    def get_internal_type(self):
        """Store the bitmask using the backend's IntegerField column type."""
        return "IntegerField"

    def to_python(self, value):
        """
        Convert a stored bitmask into the set of selected categories.

        Falsy values (``None``, ``0``, empty collections) become an empty set.
        Non-integer values (e.g. a list or set assigned in Python code) are
        returned unchanged.

        :raises ValueError: if ``value`` is negative or has bits set beyond
            the last known category.
        """
        if not value:
            return set()
        # numbers.Integral covers int and Python 2's long without naming
        # ``long``, which does not exist on Python 3.
        if not isinstance(value, numbers.Integral):
            return value
        # Any bit at position >= len(categories) has no category to map to.
        if value < 0 or value >> len(self.categories):
            raise ValueError(
                "Bitmask %r has bits outside the %d known categories"
                % (value, len(self.categories)))
        return set(category
                   for index, category in enumerate(self.categories)
                   if value & (1 << index))

    def get_db_prep_value(self, value, connection=None, prepared=False):
        """
        Convert an iterable of categories into the integer bitmask to store.

        Falsy values are stored as 0 and integers are passed through
        unchanged. Items that are not in ``categories`` are silently ignored.

        :param connection: accepted for compatibility with the Django >= 1.2
            ``get_db_prep_value`` signature; unused.
        :param prepared: accepted for the same reason; unused.
        """
        if not value:
            return 0
        if isinstance(value, numbers.Integral):
            return value
        # Build the membership set once rather than scanning ``value`` for
        # every category.
        selected = set(value)
        # Each index contributes a distinct power of two, so summing them is
        # equivalent to OR-ing the bits together.
        return sum(1 << index
                   for index, category in enumerate(self.categories)
                   if category in selected)


############# Tests ############
from django.test import TestCase

# Category choices for the CheeseShop test model. Each name's position in
# this list selects the bit CategoriesField uses for it, so the order must
# stay fixed once rows have been saved.
CHEESES = [
    'Red Leicester',
    'Tilsit',
    'Caerphilly',
    'Bel Paese',
    'Illchester',
    'Gouda',
    'Venezuelan Beaver',
]

class CheeseShop(models.Model):
    """Test model holding a selection of cheeses in one CategoriesField."""

    cheeses = CategoriesField(categories=CHEESES, blank=True)

class CategoriesFieldTests(TestCase):
    """Round-trip tests for CategoriesField using the CheeseShop model."""

    def setUp(self):
        self.cheeses = ['Venezuelan Beaver', 'Tilsit', 'Illchester']
        self.cheese_shop = CheeseShop(cheeses=self.cheeses)
        self.cheese_shop.save()

    def testDataIntegrity(self):
        """
        Tests that data remains the same when saved to database
        """
        self.assertEqual(set(self.cheese_shop.cheeses), set(self.cheeses))

    def testDataIntegrityFetch(self):
        """
        Tests that data remains the same when fetched from database
        """
        cheese_shop = CheeseShop.objects.get(pk=self.cheese_shop.pk)
        self.assertEqual(set(cheese_shop.cheeses), set(self.cheeses))

    def testStoredBitmask(self):
        """
        Tests that each category is stored as the bit at its index in CHEESES
        """
        # values_list reads the raw column, bypassing the field's to_python.
        stored = CheeseShop.objects.filter(
            pk=self.cheese_shop.pk).values_list('cheeses', flat=True)[0]
        # Tilsit is index 1, Illchester index 4, Venezuelan Beaver index 6.
        self.assertEqual(stored, (1 << 1) | (1 << 4) | (1 << 6))

    def testEmptyRoundTrip(self):
        """
        Tests that an empty selection is saved and fetched as an empty set
        """
        shop = CheeseShop(cheeses=[])
        shop.save()
        fetched = CheeseShop.objects.get(pk=shop.pk)
        self.assertEqual(fetched.cheeses, set())