from __future__ import print_function, division
from functools import wraps
from sympy.core import Add, Mul, Pow, S, sympify, Float
from sympy.core.basic import Basic
from sympy.core.compatibility import default_sort_key, string_types
from sympy.core.function import Lambda
from sympy.core.mul import _keep_coeff
from sympy.core.symbol import Symbol
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence
# Backwards compatibility
from sympy.codegen.ast import Assignment
class requires(object):
    """Decorator that registers requirements each time a print method runs.

    Every keyword argument names an attribute of the printer instance.  On
    each call of the decorated method, that attribute's ``update`` method is
    invoked with the keyword's value before the wrapped method executes.
    """

    def __init__(self, **kwargs):
        self._req = kwargs

    def __call__(self, method):
        @wraps(method)
        def _method_wrapper(self_, *args, **kwargs):
            # Record the requirements on the printer before printing.
            for attr_name, additions in self._req.items():
                target = getattr(self_, attr_name)
                target.update(additions)
            return method(self_, *args, **kwargs)

        return _method_wrapper
class AssignmentError(Exception):
    """
    Raised if an assignment variable for a loop is missing.
    """
    pass
class CodePrinter(StrPrinter):
    """
    The base class for code-printing subclasses.
    """
    # Spellings of the logical connectives, looked up by _print_And,
    # _print_Or and _print_Not.  _print_Xor and _print_Equivalent look for
    # 'xor' / 'equivalent' entries, which are absent here, so those
    # expressions are reported as not supported unless the keys are added.
    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }
    # Default printer settings.
    #   error_on_reserved      -- _print_Symbol raises ValueError for a symbol
    #                             whose name is in ``reserved_words``.
    #   reserved_word_suffix   -- otherwise appended to such a name.
    #   human                  -- doprint returns one string when true, and a
    #                             (constants, not_supported, code) tuple
    #                             when false.
    #   inline                 -- _print_NumberSymbol prints the numeric value
    #                             in place instead of registering a constant.
    #   allow_unknown_functions -- _print_Function prints functions it does
    #                             not know as plain calls.
    #   order, full_prec       -- not read directly in this class; presumably
    #                             consumed by the base printer (TODO confirm).
    _default_settings = {
        'order': None,
        'full_prec': 'auto',
        'error_on_reserved': False,
        'reserved_word_suffix': '_',
        'human': True,
        'inline': False,
        'allow_unknown_functions': False,
    }
    def __init__(self, settings=None):
        """Initialise the base printer and guarantee ``reserved_words`` exists.

        A ``reserved_words`` attribute that is already present (for example
        one defined on a subclass) is left untouched; otherwise an empty set
        is installed.
        """
        super(CodePrinter, self).__init__(settings=settings)
        if hasattr(self, 'reserved_words'):
            return
        self.reserved_words = set()
    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ----------
        expr : Expression
            The expression to be printed.
        assign_to : Symbol, MatrixSymbol, or string (optional)
            If provided, the printed code will set the expression to a
            variable with name ``assign_to``.

        Returns
        -------
        str or tuple
            With the ``human`` setting enabled, a single string: comment
            lines for unsupported expressions, then declarations of numeric
            constants, then the formatted code.  Otherwise a tuple
            ``(number_symbols, not_supported, code)`` where
            ``number_symbols`` is a set of ``(symbol, printed_value)`` pairs.

        Raises
        ------
        TypeError
            If ``assign_to`` is neither a string, a ``Basic`` instance nor
            ``None``.
        """
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        if isinstance(assign_to, string_types):
            # ``expr`` has not been sympified yet, so it may be a plain Python
            # object (e.g. an int) that lacks an ``is_Matrix`` attribute.
            if getattr(expr, 'is_Matrix', False):
                assign_to = MatrixSymbol(assign_to, *expr.shape)
            else:
                assign_to = Symbol(assign_to)
        elif not isinstance(assign_to, (Basic, type(None))):
            raise TypeError("{0} cannot assign to object of type {1}".format(
                    type(self).__name__, type(assign_to)))
        if assign_to:
            expr = Assignment(assign_to, expr)
        else:
            # _sympify is not enough b/c it errors on iterables
            expr = sympify(expr)
        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()
        lines = self._print(expr).splitlines()
        # format the output
        if self._settings["human"]:
            frontlines = []
            if self._not_supported:
                frontlines.append(self._get_comment(
                        "Not supported in {0}:".format(self.language)))
                # A separate loop name keeps ``expr`` from being rebound.
                for unsupported in sorted(self._not_supported, key=str):
                    frontlines.append(
                        self._get_comment(type(unsupported).__name__))
            for name, value in sorted(self._number_symbols, key=str):
                frontlines.append(self._declare_number_const(name, value))
            lines = frontlines + lines
            lines = self._format_code(lines)
            result = "\n".join(lines)
        else:
            lines = self._format_code(lines)
            num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
            result = (num_syms, self._not_supported, "\n".join(lines))
        # Drop the per-call bookkeeping so it does not leak into the next call.
        self._not_supported = set()
        self._number_symbols = set()
        return result
    def _doprint_loops(self, expr, assign_to=None):
        """Print ``expr`` as loop code, accumulating into ``assign_to``.

        Terms without summation are emitted first, inside loops over the
        non-dummy indices.  Each term with summation indices is then emitted
        inside additional loops over those indices, using ``assign_to`` as
        the accumulator.

        Returns the generated lines joined with newlines.

        Raises
        ------
        NotImplementedError
            If a factor of a summed term has its own internal contractions.
        AssignmentError
            If a summed term is present but ``assign_to`` is ``None``.
        ValueError
            If a summed term contains ``assign_to``.
        """
        # Here we print an expression that contains Indexed objects, they
        # correspond to arrays in the generated code.  The low-level implementation
        # involves looping over array elements and possibly storing results in temporary
        # variables or accumulate it in the assign_to object.
        if self._settings.get('contract', True):
            from sympy.tensor import get_contraction_structure
            # Setup loops over non-dummy indices  --  all terms need these
            indices = self._get_expression_indices(expr, assign_to)
            # Setup loops over dummy indices  --  each term needs separate treatment
            dummies = get_contraction_structure(expr)
        else:
            indices = []
            dummies = {None: (expr,)}
        openloop, closeloop = self._get_loop_opening_ending(indices)
        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)
        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            if assign_to is not None:
                text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)
        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    indices)
                for term in dummies[d]:
                    if term in dummies and not ([list(f.keys()) for f in dummies[term]]
                            == [[None] for f in dummies[term]]):
                        # If one factor in the term has its own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:
                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        if assign_to is None:
                            raise AssignmentError(
                                "need assignment variable for loops")
                        if term.has(assign_to):
                            # Adjacent literals keep source indentation out
                            # of the message text.
                            raise ValueError(
                                "FIXME: lhs present in rhs, "
                                "this is undefined in CodePrinter")
                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)
        return "\n".join(lines)
    def _get_expression_indices(self, expr, assign_to):
        """Return the non-dummy indices shared by ``expr`` and ``assign_to``.

        The indices of both sides are obtained with ``get_indices``.  If only
        the left-hand side carries indices, they are reused for the right-hand
        side (scalar broadcast).  The result is ordered by
        ``_sort_optimized`` against ``assign_to``.

        Raises ``ValueError`` when the two index collections differ.
        """
        from sympy.tensor import get_indices
        rhs_indices, _ = get_indices(expr)
        lhs_indices, _ = get_indices(assign_to)
        # A scalar right-hand side is broadcast over the lhs indices.
        if lhs_indices and not rhs_indices:
            rhs_indices = lhs_indices
        if rhs_indices != lhs_indices:
            message = ("lhs indices must match non-dummy"
                       " rhs indices in %s" % expr)
            raise ValueError(message)
        return self._sort_optimized(rhs_indices, assign_to)
    def _sort_optimized(self, indices, expr):
        """Return ``indices`` sorted by ascending loop-order score.

        Each index starts at score 0.  For every ``Indexed`` atom in ``expr``
        the score of each of its indices is increased by
        ``_rate_index_position(position)``; indices that are not in
        ``indices`` are ignored.  The highest-scoring index comes last, i.e.
        ends up in the innermost loop.
        """
        from sympy.tensor.indexed import Indexed
        if not indices:
            return []
        scores = dict.fromkeys(indices, 0)
        for indexed in expr.atoms(Indexed):
            for position, index in enumerate(indexed.indices):
                try:
                    scores[index] += self._rate_index_position(position)
                except KeyError:
                    # Not one of the indices we were asked to order.
                    pass
        return sorted(indices, key=lambda index: scores[index])
    def _rate_index_position(self, p):
        """Score an index by its position ``p`` among an array's indices.

        Subclasses must override this hook; the scores are consumed by
        CodePrinter._sort_optimized() to order loops.
        """
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _get_statement(self, codestring):
        """Return ``codestring`` with the proper line ending (subclass hook)."""
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _get_comment(self, text):
        """Return ``text`` formatted as a comment (subclass hook)."""
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _declare_number_const(self, name, value):
        """Return a declaration of the numeric constant ``name`` (subclass hook).

        ``doprint`` places these declarations ahead of the printed code.
        """
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _format_code(self, lines):
        """Format a list of code lines (subclass hook).

        Formatting may include indenting, wrapping long lines, etc.
        """
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _get_loop_opening_ending(self, indices):
        """Return ``(open_lines, close_lines)`` for loops over ``indices``.

        Both elements are lists of code lines.  Subclasses must override
        this hook.
        """
        message = ("This function must be implemented by "
                   "subclass of CodePrinter.")
        raise NotImplementedError(message)
    def _print_Dummy(self, expr):
        """Print a Dummy symbol.

        Names starting with ``Dummy_`` get a leading underscore; any other
        name is suffixed with ``_<dummy_index>``.
        """
        name = expr.name
        if not name.startswith('Dummy_'):
            return '%s_%d' % (name, expr.dummy_index)
        return '_' + name
    def _print_CodeBlock(self, expr):
        """Print every argument of the block, one per line."""
        printed = []
        for statement in expr.args:
            printed.append(self._print(statement))
        return '\n'.join(printed)
    def _print_String(self, string):
        """Return ``str(string)`` unchanged by any further formatting."""
        return str(string)
    def _print_QuotedString(self, arg):
        """Return ``arg.text`` wrapped in double quotes."""
        # NOTE(review): quotes embedded in ``arg.text`` are not escaped here.
        return '"%s"' % arg.text
    def _print_Comment(self, string):
        """Format ``str(string)`` with the ``_get_comment`` hook."""
        return self._get_comment(str(string))
    def _print_Assignment(self, expr):
        """Print an assignment, expanding the multi-line special cases.

        * Piecewise right-hand side: each branch becomes its own Assignment
          and the resulting Piecewise is printed.
        * MatrixSymbol left-hand side: one Assignment per element, visited in
          the order given by ``_traverse_matrix_indices``.
        * ``contract`` setting enabled and an IndexedBase on either side: the
          loops are produced by ``_doprint_loops``.
        * Otherwise: a single ``lhs = rhs`` statement.
        """
        from sympy.functions.elementary.piecewise import Piecewise
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.tensor.indexed import IndexedBase
        lhs, rhs = expr.lhs, expr.rhs

        if isinstance(rhs, Piecewise):
            # Push the assignment into every branch, keeping its condition.
            branches = [(Assignment(lhs, value), condition)
                        for (value, condition) in rhs.args]
            return self._print(Piecewise(*branches))

        if isinstance(lhs, MatrixSymbol):
            element_lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                element = Assignment(lhs[i, j], rhs[i, j])
                element_lines.append(self._print(element))
            return "\n".join(element_lines)

        if self._settings.get("contract", False) and (
                lhs.has(IndexedBase) or rhs.has(IndexedBase)):
            return self._doprint_loops(rhs, lhs)

        lhs_code = self._print(lhs)
        rhs_code = self._print(rhs)
        return self._get_statement("%s = %s" % (lhs_code, rhs_code))
    def _print_AugmentedAssignment(self, expr):
        """Print ``lhs <op> rhs`` as a statement, with ``<op>`` from ``expr.op``."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        # The already-printed operands and the operator all go through
        # ``_print`` once more, exactly as before.
        pieces = [self._print(piece)
                  for piece in (lhs_code, expr.op, rhs_code)]
        return self._get_statement("{0} {1} {2}".format(*pieces))
    def _print_FunctionCall(self, expr):
        """Print ``name(arg1, arg2, ...)`` from ``expr.name`` and ``expr.function_args``."""
        name = expr.name
        printed_args = [self._print(arg) for arg in expr.function_args]
        return '%s(%s)' % (name, ', '.join(printed_args))
    def _print_Variable(self, expr):
        """Print only the variable's ``symbol``."""
        return self._print(expr.symbol)
    def _print_Statement(self, expr):
        """Print the single argument of ``expr`` with a statement ending."""
        (inner,) = expr.args
        printed = self._print(inner)
        return self._get_statement(printed)
    def _print_Symbol(self, expr):
        """Print a symbol, handling names listed in ``reserved_words``.

        A reserved name raises ``ValueError`` when the ``error_on_reserved``
        setting is true; otherwise ``reserved_word_suffix`` is appended.
        """
        name = super(CodePrinter, self)._print_Symbol(expr)
        if name not in self.reserved_words:
            return name
        if self._settings['error_on_reserved']:
            msg = ('This expression includes the symbol "{}" which is a '
                   'reserved keyword in this language.')
            raise ValueError(msg.format(name))
        return name + self._settings['reserved_word_suffix']
    def _print_Function(self, expr):
        """Print a function application.

        Handlers are tried in this order:

        1. An entry in ``self.known_functions`` keyed by the function's class
           name.  The entry is either a string (the target function name) or
           an iterable of ``(condition, printer)`` pairs, where ``condition``
           is called with the expression's arguments.
        2. An ``_imp_`` attribute that is a ``Lambda``: it is applied to the
           arguments and the result is printed.
        3. With the ``allow_unknown_functions`` setting, a plain call.
        4. Otherwise the expression is reported as not supported.

        If a ``known_functions`` entry selects no printer (for example an
        empty list of pairs), the later handlers are tried instead of
        implicitly returning ``None``.
        """
        func_name = expr.func.__name__
        if func_name in self.known_functions:
            cond_func = self.known_functions[func_name]
            func = None
            if isinstance(cond_func, string_types):
                func = cond_func
            else:
                # NOTE(review): when no condition holds, ``func`` keeps the
                # printer of the last pair and is still used below; this
                # long-standing behaviour is kept.
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
            if func is not None:
                try:
                    # ``func`` may be a callable that builds the code itself...
                    return func(*[self.parenthesize(item, 0) for item in expr.args])
                except TypeError:
                    # ...or a plain name to be printed as an ordinary call.
                    return "%s(%s)" % (func, self.stringify(expr.args, ", "))
        if hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        if expr.is_Function and self._settings.get('allow_unknown_functions', False):
            return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args)))
        return self._print_not_supported(expr)
    # ``_print_Expr`` is the same function object as ``_print_Function``.
    _print_Expr = _print_Function
    def _print_NumberSymbol(self, expr):
        """Print a number symbol inline or register it as a named constant.

        The value is ``expr`` evaluated to the ``precision`` setting.  With
        the ``inline`` setting the value itself is printed.  Otherwise the
        ``(expr, value)`` pair is added to ``self._number_symbols`` and
        ``str(expr)`` is returned.
        """
        value = Float(expr.evalf(self._settings["precision"]))
        if self._settings.get("inline", False):
            return self._print(value)
        # A Number symbol that is not implemented here or with _printmethod
        # is registered and evaluated
        self._number_symbols.add((expr, value))
        return str(expr)
    def _print_Catalan(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_EulerGamma(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_GoldenRatio(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_TribonacciConstant(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_Exp1(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_Pi(self, expr):
        """Delegate to ``_print_NumberSymbol``."""
        return self._print_NumberSymbol(expr)
    def _print_And(self, expr):
        """Join the sorted, parenthesized arguments with the 'and' operator."""
        PREC = precedence(expr)
        separator = " %s " % self._operators['and']
        ordered_args = sorted(expr.args, key=default_sort_key)
        return separator.join(
            [self.parenthesize(arg, PREC) for arg in ordered_args])
    def _print_Or(self, expr):
        """Join the sorted, parenthesized arguments with the 'or' operator."""
        PREC = precedence(expr)
        separator = " %s " % self._operators['or']
        ordered_args = sorted(expr.args, key=default_sort_key)
        return separator.join(
            [self.parenthesize(arg, PREC) for arg in ordered_args])
    def _print_Xor(self, expr):
        """Join the arguments with the 'xor' operator, if one is configured.

        Without an ``'xor'`` entry in ``_operators`` the expression is
        reported as not supported.  Arguments keep their original order.
        """
        operator = self._operators.get('xor')
        if operator is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        separator = " %s " % operator
        return separator.join(
            [self.parenthesize(arg, PREC) for arg in expr.args])
    def _print_Equivalent(self, expr):
        """Join the arguments with the 'equivalent' operator, if configured.

        Without an ``'equivalent'`` entry in ``_operators`` the expression is
        reported as not supported.  Arguments keep their original order.
        """
        operator = self._operators.get('equivalent')
        if operator is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        separator = " %s " % operator
        return separator.join(
            [self.parenthesize(arg, PREC) for arg in expr.args])
    def _print_Not(self, expr):
        """Prefix the parenthesized first argument with the 'not' operator."""
        PREC = precedence(expr)
        operator = self._operators['not']
        operand = self.parenthesize(expr.args[0], PREC)
        return operator + operand
    def _print_Mul(self, expr):
        """Print a product as ``numerator/denominator`` text.

        A negative leading coefficient is pulled out as a ``-`` prefix.
        Commutative powers with a negative rational exponent go to the
        denominator with the exponent negated; every other factor stays in
        the numerator.  Several denominator factors are wrapped in one pair
        of parentheses.
        """
        prec = precedence(expr)
        c, e = expr.as_coeff_Mul()
        if c < 0:
            # Print the product with a positive coefficient and add the sign
            # by hand.
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""
        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)
        pow_paren = []  # Will collect all pow with more than one base element and exp = -1
        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)
        # Gather args for numerator/denominator
        for item in args:
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    # Keep the power unevaluated so it prints as base**(-exp).
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    if len(item.args[0].args) != 1 and isinstance(item.base, Mul):   # To avoid situations like #14160
                        pow_paren.append(item)
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)
        # An empty numerator is printed as 1.
        a = a or [S.One]
        a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]
        # To parenthesize Pow with exp = -1 and having more than one Symbol
        for item in pow_paren:
            if item.base in b:
                b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
        if not b:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)
    def _print_not_supported(self, expr):
        """Record ``expr`` as unsupported and print it with ``emptyPrinter``.

        ``self._not_supported`` is created by ``doprint``.  When printing is
        started some other way the attribute may not exist; in that case the
        expression is simply not recorded instead of raising
        ``AttributeError``.
        """
        registry = getattr(self, '_not_supported', None)
        if registry is not None:
            registry.add(expr)
        return self.emptyPrinter(expr)
    # The following can not be simply translated into C or Fortran.
    # Each name below is bound to _print_not_supported, so printing one of
    # these expression types records it as unsupported and falls back to
    # emptyPrinter.
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_AccumulationBounds = _print_not_supported
    _print_Limit = _print_not_supported
    _print_Matrix = _print_not_supported
    _print_ImmutableMatrix = _print_not_supported
    _print_ImmutableDenseMatrix = _print_not_supported
    _print_MutableDenseMatrix = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Normal = _print_not_supported
    _print_Order = _print_not_supported
    _print_PDF = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_Sample = _print_not_supported
    _print_SparseMatrix = _print_not_supported
    _print_MutableSparseMatrix = _print_not_supported
    _print_ImmutableSparseMatrix = _print_not_supported
    _print_Uniform = _print_not_supported
    _print_Unit = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported