from django.db.transaction import managed, enter_transaction_management, is_dirty, rollback, commit, leave_transaction_management, _transaction_func

import threading

# Per-thread storage used by nested_commit_on_success to track how deeply its
# blocks are nested on the current thread (threading.local gives each thread
# its own independent set of attributes, so threads never see each other's
# nesting state).
_tl = threading.local()

def nested_commit_on_success(using=None):
    """Nestable variant of Django's ``commit_on_success``.

    Only the outermost block for a given database alias (per thread) enters
    transaction management and, on exit, commits (or rolls back on error).
    Inner blocks for the same alias do nothing on entry or exit, so their work
    is committed or rolled back together with the outermost block.

    Nesting is tracked per database alias, so a block for one database nested
    inside a block for another database still manages its own transaction.

    :param using: database alias, or a callable when used as a bare
        decorator; argument handling is delegated to Django's
        ``_transaction_func``, exactly as ``commit_on_success`` does.
    :returns: whatever ``_transaction_func`` returns (a decorator / context
        manager object, or the decorated function).
    """

    def _levels():
        # Per-thread mapping {database alias: current nesting depth}. Created
        # lazily because threading.local attributes start out empty in each
        # thread.
        levels = getattr(_tl, "levels", None)
        if levels is None:
            levels = _tl.levels = {}
        return levels

    def entering(using):
        levels = _levels()
        depth = levels.get(using, 0)
        if depth == 0:
            # Outermost block for this alias: really start the transaction.
            # The depth is recorded only after this succeeds, because exiting()
            # is not called when entering() raises; incrementing first would
            # leave the counter stuck and turn every later block on this
            # thread into a silent no-op.
            enter_transaction_management(using=using)
            managed(True, using=using)
        levels[using] = depth + 1

    def exiting(exc_value, using):
        levels = _levels()
        depth = levels.get(using, 0) - 1
        if depth > 0:
            # Still nested: leave commit/rollback to the outermost block.
            levels[using] = depth
            return
        # Outermost block: reset the counter and end the transaction properly.
        levels.pop(using, None)
        try:
            if exc_value is not None:
                if is_dirty(using=using):
                    rollback(using=using)
            else:
                if is_dirty(using=using):
                    try:
                        commit(using=using)
                    except:
                        # Deliberately bare: roll back on *any* failure
                        # (including old-style exceptions on Python 2), then
                        # re-raise it unchanged.
                        rollback(using=using)
                        raise
        finally:
            leave_transaction_management(using=using)

    return _transaction_func(entering, exiting, using)

