import sys
import socket
import atexit
import struct
from cStringIO import StringIO
#from threading import local

from django.conf import settings
from django.db.models.base import Model

# known searchd commands
SEARCHD_COMMAND_SEARCH = 0
SEARCHD_COMMAND_EXCERPT = 1

# current client-side command implementation versions
VER_COMMAND_SEARCH = 0x104
VER_COMMAND_EXCERPT = 0x100
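# (the high byte holds the major and the low byte the minor version number)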

# known searchd status codes
SEARCHD_OK = 0
SEARCHD_ERROR = 1
SEARCHD_RETRY = 2

# known match modes
SPH_MATCH_ALL = 0
SPH_MATCH_ANY = 1
SPH_MATCH_PHRASE = 2
SPH_MATCH_BOOLEAN = 3
SPH_MATCH_EXTENDED = 4

# known sort modes
SPH_SORT_RELEVANCE = 0
SPH_SORT_ATTR_DESC = 1
SPH_SORT_ATTR_ASC = 2
SPH_SORT_TIME_SEGMENTS = 3
SPH_SORT_EXTENDED = 4

# known attribute types
SPH_ATTR_INTEGER = 1
SPH_ATTR_TIMESTAMP = 2

# known grouping functions
SPH_GROUPBY_DAY = 0
SPH_GROUPBY_WEEK = 1
SPH_GROUPBY_MONTH = 2
SPH_GROUPBY_YEAR = 3
SPH_GROUPBY_ATTR = 4


class SphinxError(Exception): pass
class SphinxConnectionError(SphinxError): pass
class SphinxTransportError(SphinxError): pass
class SphinxFormatError(SphinxError): pass
class SphinxInputError(SphinxError): pass
class SphinxTemporaryError(SphinxError): pass

class SphinxClient(object): #(local):
    """Client to send requests to Sphinx searchd and interpret results.
    
    The client can be used for multiple requests against the same host/port.
    
    """
    
    def __init__(self, host='localhost', port=3312, timeout=0):
        """Set timeout to 0 or None if you do not need it."""
        self.host = host
        self.port = port
        if timeout:
            socket.setdefaulttimeout(timeout)
        self._sock = None
        
    def query(self, query, index='*', offset=0, limit=20, **kw):
        """Get search results from searchd.
        
        Keyword arguments:
        query -- search terms as a single string
        index -- the index to query against, defaults to '*'
        offset -- fetch results starting from offset, defaults to 0
        limit -- number of results to return, defaults to 20
        mode -- search mode, defaults to SPH_MATCH_EXTENDED
        sort -- sort mode, defaults to SPH_SORT_RELEVANCE
        sortby -- attribute to sort by, used with the SPH_SORT_ATTR_* and
            SPH_SORT_EXTENDED modes, defaults to ''
        weights -- list of integer per-field weights, defaults to []
        min_id -- smallest document id to match, defaults to 0
        max_id -- largest document id to match, defaults to 0xFFFFFFFF
        filter -- dict mapping attribute names to lists of accepted values,
            defaults to {}
        min_ -- dict mapping attribute names to minimum range values; every
            key must also appear in max_, defaults to {}
        max_ -- dict mapping attribute names to maximum range values,
            defaults to {}
        groupby -- attribute to group by, defaults to ''
        groupfunc -- grouping function, defaults to SPH_GROUPBY_DAY
        maxmatches -- defaults to 1000
        
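        A minimal sketch of a call (the index name and filter attribute are
        hypothetical):
        
            client = SphinxClient()
            result = client.query('hello world', index='articles',
                                  filter={'category_id': [1]})
            for doc_id, weight, attrs in result['matches']:
                print doc_id, weight
        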
        """
        mode = kw.get('mode', SPH_MATCH_EXTENDED)
        sort = kw.get('sort', SPH_SORT_RELEVANCE)
        sortby = kw.get('sortby', '')
        weights = kw.get('weights', [])
        min_id = kw.get('min_id', 0)
        max_id = kw.get('max_id', 0xFFFFFFFF)
        filter = kw.get('filter', {})
        min_ = kw.get('min_', {})
        max_ = kw.get('max_', {})
        groupby = kw.get('groupby', '')
        groupfunc = kw.get('groupfunc', SPH_GROUPBY_DAY)
        maxmatches = kw.get('maxmatches', 1000)
        # check args
        assert SPH_MATCH_ALL <= mode <= SPH_MATCH_EXTENDED
        assert SPH_SORT_RELEVANCE <= sort <= SPH_SORT_EXTENDED
        assert min_id <= max_id
        assert SPH_GROUPBY_DAY <= groupfunc <= SPH_GROUPBY_ATTR
        # build request
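        # a request is an 8-byte header ('!HHL': command, command version,
        # body length) followed by the packed body assembled below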
        buffer = StringIO()
        try:
            # offset, limit, mode, sort
            buffer.write(struct.pack('!LLLL', offset, limit, mode, sort))
            # sortby, query
            for v in [sortby, query]:
                buffer.write(struct.pack('!L', len(v)))
                buffer.write(v)
            # weights
            buffer.write(struct.pack('!L', len(weights)))
            for w in weights:
                buffer.write(struct.pack('!L', w))
            # index
            buffer.write(struct.pack('!L', len(index)))
            buffer.write(index)
            # id range and the number of range/value filters
            for i in (min_id, max_id, len(min_) + len(filter)):
                buffer.write(struct.pack('!L', i))
            for k, v in min_.items():
                buffer.write(struct.pack('!L', len(k)))
                buffer.write(k)
                buffer.write(struct.pack('!LLL', 0, v, max_[k]))
            for k, values in filter.items():
                buffer.write(struct.pack('!L', len(k)))
                buffer.write(k)
                buffer.write(struct.pack('!L', len(values)))
                for v in values:
                    buffer.write(struct.pack('!L', v))
            # groupby
            buffer.write(struct.pack('!LL', groupfunc, len(groupby)))
            buffer.write(groupby)
            # maxmatches
            buffer.write(struct.pack('!L', maxmatches))
            data = buffer.getvalue()
            req = struct.pack('!HHL', SEARCHD_COMMAND_SEARCH, VER_COMMAND_SEARCH, len(data)) + data
        except (struct.error, TypeError), e:
            raise SphinxInputError, "Error generating request, %s" % e, sys.exc_info()[2]
        self._connect()
        self._write(req)
        # read back result
        result = dict()
        data = self._get_response(VER_COMMAND_SEARCH)
        data_len = len(data)
        pos = 0
        fields = list()
        try:
            num = struct.unpack('!L', data[pos:pos+4])[0]
            pos += 4
            while len(fields) < num and pos < data_len:
                l = struct.unpack('!L', data[pos:pos+4])[0]
                pos += 4
                fields.append(data[pos:pos+l])
                pos += l
            result['fields'] = fields
            attrs = list()
            num = struct.unpack('!L', data[pos:pos+4])[0]
            pos += 4
            while len(attrs) < num and pos < data_len:
                l = struct.unpack('!L', data[pos:pos+4])[0]
                pos += 4
                k = data[pos:pos+l]
                pos += l
                v = struct.unpack('!L', data[pos:pos+4])[0]
                pos += 4
                attrs.append((k, v))
            result['attrs'] = attrs
            matches = []
            num_matches = struct.unpack('!L', data[pos:pos+4])[0]
            pos += 4
            while len(matches) < num_matches and pos < data_len:
                doc = struct.unpack('!L', data[pos:pos+4])[0]
                weight = struct.unpack('!L', data[pos+4:pos+8])[0]
                pos += 8
                doc_attrs = dict()
                for attr, attr_value in attrs:
                    doc_attrs[attr] = struct.unpack('!L', data[pos:pos+4])[0]
                    pos += 4
                matches.append((doc, weight, doc_attrs))
            result['matches'] = matches
            result['total'], result['total_found'], result['time'], num_words = struct.unpack('!LLLL', data[pos:pos+16])
            pos += 16
            words = dict()
            for i in range(num_words):
                l = struct.unpack('!L', data[pos:pos+4])[0]
                pos += 4
                word = data[pos:pos+l]
                pos += l
                docs, hits = struct.unpack('!LL', data[pos:pos+8])
                pos += 8
                words[word] = dict(docs=docs, hits=hits)
        except (struct.error, TypeError), e:
            raise SphinxFormatError, "error unpacking result, %s" % e, sys.exc_info()[2]
        result['words'] = words
        self._disconnect()
        return result
        
    def build_excerpts(self, docs, index, words='', **kw):
        """Get excerpts from searchd for a list of documents.
        
        Keyword arguments:
        docs -- a list of document bodies
        index -- the index to use
        words -- keywords to highlight as a single string
        before_match -- prefix for keyword matches, defaults to '<strong>'
        after_match -- suffix for keyword matches, defaults to '</strong>'
        chunk_separator -- separator between excerpt chunks, defaults to ' ... '
        limit -- maximum excerpt size in characters, defaults to 256
        around -- number of words to pick around each match, defaults to 5
        
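        A minimal sketch of a call (the index name is hypothetical):
        
            client = SphinxClient()
            excerpts = client.build_excerpts(
                ['first document body', 'second document body'],
                'articles', words='document')
        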
        """
        before_match = kw.get('before_match', '<strong>')
        after_match = kw.get('after_match', '</strong>')
        chunk_separator = kw.get('chunk_separator', ' ... ')
        limit = kw.get('limit', 256)
        around = kw.get('around', 5)
        # build request
        buffer = StringIO()
        try:
            buffer.write(struct.pack('!LL', 0, 1))  # mode and flags words
            for opt in (index, words, before_match, after_match, chunk_separator):
                buffer.write(struct.pack('!L', len(opt)))
                buffer.write(opt)
            buffer.write(struct.pack('!L', limit))
            buffer.write(struct.pack('!L', around))
            buffer.write(struct.pack('!L', len(docs)))
            for d in docs:
                buffer.write(struct.pack('!L', len(d)))
                buffer.write(d)
            data = buffer.getvalue()
            req = struct.pack('!HHL', SEARCHD_COMMAND_EXCERPT, VER_COMMAND_EXCERPT, len(data)) + data
        except (struct.error, TypeError), e:
            raise SphinxInputError, "Error generating request, %s" % e, sys.exc_info()[2]
        self._connect()
        self._write(req)
        # read back result
        result = list()
        data = self._get_response(VER_COMMAND_EXCERPT)
        data_len = len(data)
        pos = 0
        try:
            for d in docs:
                l = struct.unpack('!L', data[pos:pos+4])[0]
                pos += 4
                result.append(data[pos:pos+l])
                pos += l
        except (struct.error, TypeError), e:
            raise SphinxFormatError, "error unpacking result, %s" % e, sys.exc_info()[2]
        self._disconnect()
        return result
            
    def _get_response(self, client_version):
        # fetch the response from searchd and split it in header and data
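        # the 8-byte header packs status (2 bytes), version (2 bytes) and
        # body length (4 bytes) in network byte order ('!HHL')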
        header = self._read(8)
        try:
            status, version, length = struct.unpack('!HHL', header)
        except (ValueError, struct.error), e:
            raise SphinxError, "error unpacking response header %s" % e
        data = self._read(length)
        if status == SEARCHD_ERROR:
            raise SphinxError, data[4:]
        if status == SEARCHD_RETRY:
            raise SphinxTemporaryError, data[4:]
        if status != SEARCHD_OK:
            raise SphinxError, "unkown status code %s" % status
        if version < client_version:
            # TODO: use logging
            print >>sys.stderr, "searchd command v.%d.%d older than client's v.%d.%d" % (
                version >> 8, version & 0xff, client_version >> 8, client_version & 0xff)
        return data
        
    def _read(self, length):
        # read from socket
        msg = list()
        received = 0
        while received < length:
            try:
                chunk = self._sock.recv(length - received)
            except socket.error, e:
                raise SphinxTransportError, "error while reading from socket, %s" % e
            if chunk == '':
                raise SphinxError, "socket connection broken, read %s bytes" % received
            msg.append(chunk)
            received += len(chunk)
        return ''.join(msg)
    
    def _write(self, buffer):
        # write to socket
        sent = 0
        while sent < len(buffer):
            try:
                s = self._sock.send(buffer[sent:])
            except socket.error, e:
                raise SphinxTransportError, "error while writing to socket, %s" % e
            if s == 0:
                raise SphinxError, "socket connection broken"
            sent += s
        return sent
        
    def _connect(self):
        # connect to searchd
        if not self._sock:
            try:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.connect((self.host, self.port))
                self._sock = s
                atexit.register(self._disconnect)
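                # handshake: searchd announces its protocol version as a
                # single 32-bit word and the client answers with its own (1)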
                version = struct.unpack('!L', self._read(4))[0]
                if version < 1:
                    raise SphinxConnectionError, "invalid protocol version %s received from server" % version
                self._write(struct.pack('!L', 1))
            except (socket.error, struct.error), e:
                raise SphinxConnectionError, e

    def _disconnect(self):
        # disconnect from searchd
        if self._sock:
            try:
                self._sock.close()
            except socket.error:
                pass
            self._sock = None

class SphinxSearch(object):
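    """Descriptor that exposes Sphinx search results as model instances.
    
    Matches carry a sphinx_weight attribute and, when an excerpts_field is
    configured, a sphinx_excerpt attribute. A minimal sketch of the intended
    usage (model and field names are hypothetical):
    
        class Article(models.Model):
            body = models.TextField()
            search = SphinxSearch(excerpts_field='body')
        
        for article in Article.search.query('django')[:10]:
            print article.sphinx_weight, article.sphinx_excerpt
    """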
    
    def __init__(self, index=None, excerpts_field=None, **kw):
        # TODO: basic search parameters
        self._model = None
        self.index = index
        if excerpts_field:
            if isinstance(excerpts_field, tuple):
                try:
                    attr, field = excerpts_field
                except ValueError:
                    raise AssertionError("excerpts_fields has to be a string or a two-element tuple.")
            elif isinstance(excerpts_field, basestring):
                attr = excerpts_field
                field = None
            else:
                raise AssertionError("excerpts_fields has to be a string or a two-element tuple.")
            charset = getattr(settings, 'DEFAULT_CHARSET', 'utf-8')
            def func(o):
                text = getattr(o, attr)
                if field:
                    text = getattr(text, field)
                if isinstance(text, unicode):
                    text = text.encode(charset, 'replace')
                return text
            self._get_excerpt = func
        else:
            self._get_excerpt = None
        self._build_excerpts = excerpts_field is not None
        self._query = None
        self._query_cache = None
        self._offset = 0
        self._limit = 20
        self._client = SphinxClient(
            getattr(settings, 'SPHINX_SERVER', 'localhost'),
            getattr(settings, 'SPHINX_PORT', 3312))
        self.before_match = getattr(settings, 'SPHINX_EXCERPT_BEFORE_MATCH', '<strong>')
        self.after_match = getattr(settings, 'SPHINX_EXCERPT_AFTER_MATCH', '</strong>')
        self._select_related = False
        self._select_related_args = dict()
        self._extra = dict()
        
    def __get__(self, instance, owner):
        if instance is not None or Model not in owner.__bases__:
            raise AttributeError, "Search manager is only accessible via a model class"
        self._model = owner
        if not self.index:
            self.index = self._model._meta.db_table
        return self

    def __len__(self):
        self._matches  # runs the query and populates self._result
        return self._result.get('total', 0)
    
    def __iter__(self):
        return iter(self._matches)
    
    def __getitem__(self, k):
        if not isinstance(k, (slice, int)):
            raise TypeError
        if isinstance(k, slice):
            start = k.start or 0
            stop = k.stop if k.stop is not None else start + self._limit
            if start < 0 or stop < 0:
                raise AssertionError("Negative indexing is not supported.")
            num = stop - start
            if start < self._offset or (start - self._offset) + num > self._limit:
                # requested slice falls outside the cached window, re-query
                self._query_cache = None
                self._offset = start
                self._limit = stop - start
                return self._matches
            return self._matches[start - self._offset:stop - self._offset]
        else:
            if k < 0:
                raise AssertionError("Negative indexing is not supported.")
            if k < self._offset or k >= self._offset + self._limit:
                # requested index falls outside the cached window, re-query
                self._query_cache = None
                self._offset = k
                self._limit = 1
                return self._matches[0]
            return self._matches[k - self._offset]

    @property
    def _matches(self):
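        # run the query against searchd, map the returned document ids back
        # to model instances, and decorate each with sphinx_weight (and
        # sphinx_excerpt when excerpts are enabled)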
        if self._query_cache is not None:
            return self._query_cache
        self._result = self._client.query(self._query, index=self.index, offset=self._offset, limit=self._limit)
        qs = self._model.objects.filter(pk__in=[m[0] for m in self._result['matches']])
        if self._select_related:
            qs = qs.select_related(**self._select_related_args)
        if self._extra:
            qs = qs.extra(**self._extra)
        qs = dict((o.pk, o) for o in qs)
        matches = list()
        for id, weight, extra in self._result['matches']:
            match = qs.get(id)
            if not match:
                continue
            match.sphinx_weight = weight
            matches.append(match)
        # build excerpts
        if self._build_excerpts:
            excerpts = self._client.build_excerpts(
                [self._get_excerpt(o) for o in matches],
                self.index, self._query,
                before_match=self.before_match,
                after_match=self.after_match)
            if len(excerpts) == len(matches):
                for i, e in enumerate(excerpts):
                    matches[i].sphinx_excerpt = e
        self._query_cache = matches
        return matches
    
    def query(self, query, build_excerpts=True):
        if query == self._query:
            return self
        #print "new query", query, self._query
        self._query = query
        self._query_cache = None
        if self._get_excerpt and build_excerpts:
            self._build_excerpts = True
        self._offset = 0
        self._limit = 20
        return self
    
    def select_related(self, **kw):
        self._select_related = True
        self._select_related_args.update(**kw)
        return self
    
    def extra(self, **kw):
        self._extra.update(**kw)
        return self

    def count(self):
        self._matches  # runs the query and populates self._result
        return self._result.get('total', 0)
    
