from django.db.models import Expression

class Filter(Expression):
    template = '%(expression)s FILTER (WHERE %(condition)s)'

    def __init__(self, expression, condition, output_field=None):
        if not expression.contains_aggregate:
            raise TypeError('Expression must either be an aggregate function or contain an aggregate function')

        if not hasattr(condition, 'resolve_expression'):
            raise TypeError('Condition must be a class defining resolve_expression')

        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]
        self.condition = condition
        if not getattr(self.source_expression, 'contains_aggregate', False):
            raise FieldError('Window function expressions must be aggregate functions')

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.source_expression = self.source_expression.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        c.condition = self.condition.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return c

    def _resolve_output_field(self):
        if self._output_field is None:
            self._output_field = self.source_expression.output_field

    def copy(self):
        clone = super().copy()
        clone.source_expression = self.source_expression.copy()
        clone.condition = copy.copy(self.condition)
        return clone

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        params = []
        condition_sql, condition_params = compiler.compile(self.condition)
        params.extend(condition_params)
        expr_sql, expr_params = compiler.compile(self.source_expression)
        condition_params.extend(expr_params)

        return self.template % {
            'expression': expr_sql,
            'condition': condition_sql,
        }, params

    def get_source_expressions(self):
        return self.source_expression, self.condition

    def set_source_expressions(self, exprs):
        self.source_expression, self.condition = exprs[0], exprs[1]

    def get_group_by_cols(self):
        return []

    def __str__(self):
        return self.template % {
            'expression': str(self.source_expression),
            'condition': str(self.condition),
        }

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)