###
# Implementation
###

import errno
import os
import tempfile

from django.conf import settings
from django.core.files import locks
from django.core.files.move import file_move_safe
from django.core.files.storage import FileSystemStorage, Storage
from django.utils.text import get_valid_filename

class OverwritingStorage(FileSystemStorage):
    """
    File storage that allows overwriting of stored files.

    Saving to a name that already exists replaces the stored file instead of
    choosing a new, unique name. Streamed content is first written to a
    temporary file inside ``self.location`` and only then moved over the
    destination, so a failed upload leaves a previously stored file intact.
    """

    def get_available_name(self, name, max_length=None):
        """
        Return ``name`` unchanged so that ``save()`` overwrites existing files.

        ``max_length`` is accepted (and ignored) for compatibility with
        Django versions that pass it to ``get_available_name``.
        """
        return name

    def _save(self, name, content):
        """
        Store ``content`` under ``name``, replacing any existing file.

        Lifted partially from django/core/files/storage.py.

        Returns ``name``. Raises IOError if the destination's parent path
        exists but is not a directory; any error raised while streaming the
        content propagates after the partial temporary file is removed.
        """
        full_path = self.path(name)
        self._ensure_directory(os.path.dirname(full_path))

        if hasattr(content, 'temporary_file_path'):
            # The content already lives in a file on disk; move it into place.
            # That file belongs to whoever created it, so it is never deleted
            # here on failure.
            temp_data_location = content.temporary_file_path()
            owns_temp_file = False
        else:
            temp_data_location = self._write_temp_file(name, content)
            owns_temp_file = True

        try:
            # allow_overwrite=True: the destination may already exist. Without
            # it file_move_safe refuses to replace the target whenever it has
            # to fall back from os.rename to copying (e.g. cross-device moves,
            # or Windows where rename onto an existing file fails).
            file_move_safe(temp_data_location, full_path,
                           allow_overwrite=True)
        except Exception:
            if owns_temp_file and os.path.exists(temp_data_location):
                os.remove(temp_data_location)
            raise
        content.close()

        if settings.FILE_UPLOAD_PERMISSIONS is not None:
            os.chmod(full_path, settings.FILE_UPLOAD_PERMISSIONS)

        return name

    def _ensure_directory(self, directory):
        """Create ``directory`` if missing; IOError if it is not a directory."""
        if not os.path.exists(directory):
            try:
                os.makedirs(directory)
            except OSError as e:
                # Another process may have created it between the exists()
                # check and makedirs(); that is not an error.
                if e.errno != errno.EEXIST:
                    raise
        if not os.path.isdir(directory):
            raise IOError("%s exists and is not a directory." % directory)

    def _write_temp_file(self, name, content):
        """
        Stream ``content.chunks()`` into a new temporary file in
        ``self.location`` and return its path.

        On any error while writing, the partial file is closed, removed, and
        the exception is re-raised.
        """
        tmp_prefix = "tmp_%s" % (get_valid_filename(name), )
        temp_data_location = tempfile.mktemp(prefix=tmp_prefix,
                                             dir=self.location)
        # O_EXCL makes os.open raise OSError if the file already exists, so a
        # name collision cannot silently clobber someone else's file. This
        # call is deliberately outside the cleanup handler below: if it fails,
        # the path is not ours and must not be deleted.
        fd = os.open(temp_data_location,
                     os.O_WRONLY | os.O_CREAT |
                     os.O_EXCL | getattr(os, 'O_BINARY', 0))
        try:
            try:
                locks.lock(fd, locks.LOCK_EX)
                for chunk in content.chunks():
                    os.write(fd, chunk)
                locks.unlock(fd)
            finally:
                # Always close the descriptor (this also releases the lock on
                # the error path). It must be closed before the file can be
                # removed on Windows.
                os.close(fd)
        except Exception:
            if os.path.exists(temp_data_location):
                os.remove(temp_data_location)
            raise
        return temp_data_location

###
# Tests (to be run with django-nose, although they could easily be adapted to work with unittest)
###

import os
import shutil
import tempfile

from django.core.files.base import ContentFile as C
from django.core.files import File
from django.conf import settings
from nose.tools import assert_equal

from .storage import OverwritingStorage

class TestOverwritingDefaultStorage(object):
    """
    Tests for OverwritingStorage (run with nose / django-nose).

    Each test gets a fresh, already-created temporary directory as the
    storage location; it is removed again in teardown.
    """

    def setup(self):
        # mkdtemp (unlike mktemp) actually creates the directory, so tests may
        # write helper files into it before the storage saves anything, and
        # teardown can always remove it.
        self.location = tempfile.mkdtemp(prefix="overwriting_storage_test")
        self.storage = OverwritingStorage(location=self.location)

    def teardown(self):
        shutil.rmtree(self.location)

    def _read(self, name):
        """Return the contents of stored file ``name``, closing the handle."""
        f = self.storage.open(name)
        try:
            return f.read()
        finally:
            f.close()

    def test_new_file(self):
        s = self.storage
        assert not s.exists("foo")
        s.save("foo", C("new"))
        assert_equal(self._read("foo"), "new")

    def test_overwriting_existing_file_with_string(self):
        s = self.storage

        s.save("foo", C("old"))
        name = s.save("foo", C("new"))
        assert_equal(self._read("foo"), "new")
        assert_equal(name, "foo")

    def test_overwrite_with_file(self):
        s = self.storage

        input_file = os.path.join(s.location, "input_file")
        with open(input_file, "w") as input_fh:
            input_fh.write("new")

        s.save("foo", C("old"))
        # The storage closes the content after saving it.
        name = s.save("foo", File(open(input_file)))

        assert_equal(self._read("foo"), "new")
        assert_equal(name, "foo")

    def test_upload_fails(self):
        s = self.storage

        class Explosion(Exception):
            pass

        class ExplodingContentFile(C):
            def __init__(self):
                super(ExplodingContentFile, self).__init__("")

            def chunks(self, chunk_size=None):
                yield "bad chunk"
                raise Explosion("explode!")

        s.save("foo", C("old"))

        try:
            s.save("foo", ExplodingContentFile())
        except Explosion:
            pass
        else:
            raise AssertionError("Oh no! ExplodingContentFile didn't explode.")

        # The previously stored file is untouched...
        assert_equal(self._read("foo"), "old")
        # ...and the partially written temporary file was cleaned up.
        assert_equal(os.listdir(s.location), ["foo"])