from email.MIMEBase import MIMEBase
from email.Utils import formatdate
from email.message import Message
from M2Crypto import BIO, SMIME, X509, Rand

from django.utils.encoding import smart_str
from django.core.mail import EmailMessage, SafeMIMEText, SafeMIMEMultipart, make_msgid, forbid_multi_line_headers
from django.conf import settings

class SafeMessage(Message):
    """Bare ``email.message.Message`` that vets every header on assignment.

    Used as the outer envelope of the encrypted mail. Each ``msg[name] =
    value`` is first routed through Django's ``forbid_multi_line_headers``
    (going by its name, it rejects header values spanning several lines;
    see Django for the exact rules). The possibly rewritten pair it returns
    is what actually gets stored.
    """

    def __setitem__(self, name, val):
        # Keep the explicit base-class call: on Python 2 ``Message`` is an
        # old-style class, so ``super()`` cannot be used here.
        checked_name, checked_val = forbid_multi_line_headers(name, val)
        Message.__setitem__(self, checked_name, checked_val)

class SecureEmailMessage(EmailMessage):
    """An ``EmailMessage`` whose content is S/MIME-encrypted before sending.

    ``message()`` works in three steps:

    1. Build the usual MIME entity: the text body plus any attachments.
    2. Encrypt that entity with M2Crypto to the single certificate named by
       ``cert``.
    3. Return a ``SafeMessage`` holding the encrypted payload, the S/MIME
       headers emitted by M2Crypto, and the envelope headers (Subject, From,
       To, Date, Message-ID, extra headers).
    """

    # Path of the recipient certificate handed to ``X509.load_cert``. The
    # default is relative, so it is resolved against the process CWD.
    cert = 'recipient.pem'

    # OpenSSL cipher name passed to ``SMIME.Cipher``. The default is 3-key
    # triple-DES in CBC mode. Subclasses may override it with another cipher
    # the recipient's mail client supports.
    cipher = 'des_ede3_cbc'

    # File used to seed the PRNG before encrypting and to persist its state
    # afterwards. The path is relative, so it is resolved against the CWD.
    randpool = 'randpool.dat'

    def __init__(self, *args, **kwargs):
        """Accept an optional ``cert`` keyword; pass everything else to Django.

        ``cert`` overrides the class-level certificate path for this instance
        only. It is removed from ``kwargs`` before ``EmailMessage.__init__``
        sees it.
        """
        if 'cert' in kwargs:
            self.cert = kwargs.pop('cert')
        super(SecureEmailMessage, self).__init__(*args, **kwargs)

    def message(self):
        """Return the encrypted message as a ``SafeMessage``.

        Raises whatever M2Crypto raises if the certificate cannot be loaded
        or encryption fails. Raises ``ValueError`` if M2Crypto's output
        cannot be split into headers and body.
        """
        encoding = self.encoding or settings.DEFAULT_CHARSET
        inner = self._build_inner_message(encoding)
        smime_headers, body = self._split_smime_output(
            self._smime_encrypt(inner.as_string()))

        # Work on a copy, so that repeated calls (and a caller-supplied
        # ``headers`` dict) are not polluted with S/MIME headers. S/MIME
        # headers win over user headers of the same name. The comparison is
        # case-insensitive, so a user 'content-type' cannot produce a second
        # Content-Type header.
        smime_names = set(name.lower() for name, value in smime_headers)
        headers = dict((name, value)
                       for name, value in self.extra_headers.items()
                       if name.lower() not in smime_names)
        headers.update(smime_headers)

        msg = SafeMessage()
        msg.set_payload(body)
        msg['Subject'] = self.subject
        msg['From'] = self.from_email
        msg['To'] = ', '.join(self.to)

        # Email header names are case-insensitive (RFC 2045), so we have to
        # accommodate that when doing comparisons.
        header_names = set(name.lower() for name in headers)
        if 'date' not in header_names:
            msg['Date'] = formatdate()
        if 'message-id' not in header_names:
            msg['Message-ID'] = make_msgid()
        for name, value in headers.items():
            msg[name] = value
        return msg

    def _build_inner_message(self, encoding):
        """Build the plaintext MIME entity (body plus attachments) to encrypt."""
        # Encode the body with the same charset SafeMIMEText declares, so the
        # bytes and the Content-Type charset parameter cannot disagree.
        msg = SafeMIMEText(smart_str(self.body, encoding),
                           self.content_subtype, encoding)
        if not self.attachments:
            return msg
        body_msg = msg
        msg = SafeMIMEMultipart(_subtype=self.multipart_subtype)
        if self.body:
            msg.attach(body_msg)
        for attachment in self.attachments:
            if isinstance(attachment, MIMEBase):
                msg.attach(attachment)
            else:
                msg.attach(self._create_attachment(*attachment))
        return msg

    def _smime_encrypt(self, data):
        """Encrypt ``data`` to ``self.cert``; return M2Crypto's S/MIME text."""
        # Seed the PRNG from the pool file. -1 asks OpenSSL to read all of it.
        Rand.load_file(self.randpool, -1)

        s = SMIME.SMIME()
        # Encrypt to exactly one recipient: the certificate at self.cert.
        recipients = X509.X509_Stack()
        recipients.push(X509.load_cert(self.cert))
        s.set_x509_stack(recipients)
        s.set_cipher(SMIME.Cipher(self.cipher))

        p7 = s.encrypt(BIO.MemoryBuffer(data))
        out = BIO.MemoryBuffer()
        s.write(out, p7)

        # Persist the PRNG state. This is only reached if encryption succeeded.
        Rand.save_file(self.randpool)
        return out.read()

    @staticmethod
    def _split_smime_output(text):
        """Split S/MIME text into ``([(name, value), ...], body)``.

        Accepts LF or CRLF line endings. Folded header lines are joined onto
        the header before them. Values may themselves contain ': '.

        Raises ``ValueError`` if there is no blank line separating headers
        from body, or if a header line has no ':'.
        """
        text = text.replace('\r\n', '\n')
        head, sep, body = text.partition('\n\n')
        if not sep:
            raise ValueError('S/MIME output has no header/body separator')
        headers = []
        for line in head.split('\n'):
            if not line:
                continue
            if line[0] in ' \t' and headers:
                # Continuation of a folded header: append to the previous one.
                name, value = headers[-1]
                headers[-1] = (name, value + ' ' + line.strip())
                continue
            name, colon, value = line.partition(':')
            if not colon:
                raise ValueError('malformed S/MIME header line: %r' % line)
            headers.append((name.strip(), value.strip()))
        return headers, body