# Source code for sympy.printing.codeprinter

from __future__ import print_function, division

from functools import wraps

from sympy.core import Add, Mul, Pow, S, sympify, Float
from sympy.core.basic import Basic
from sympy.core.compatibility import default_sort_key, string_types
from sympy.core.function import Lambda
from sympy.core.mul import _keep_coeff
from sympy.core.symbol import Symbol
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence

# Backwards compatibility
from sympy.codegen.ast import Assignment


class requires(object):
    """Decorator that registers requirements each time a print method runs.

    Every keyword argument names an attribute of the printer instance and
    supplies a value to merge into it.  Just before the decorated method
    executes, ``getattr(printer, name).update(value)`` is called for each
    keyword, after which the method is invoked with its original arguments.
    """

    def __init__(self, **kwargs):
        # Mapping of attribute name -> value handed to that attribute's
        # ``update`` method on every call of the decorated method.
        self._req = kwargs

    def __call__(self, method):
        @wraps(method)
        def _method_wrapper(self_, *args, **kwargs):
            # Record the requirements on the instance first, then delegate.
            for attr_name, additions in self._req.items():
                target = getattr(self_, attr_name)
                target.update(additions)
            return method(self_, *args, **kwargs)

        return _method_wrapper


class AssignmentError(Exception):
    """
    Raised if an assignment variable for a loop is missing.
    """
    pass
class CodePrinter(StrPrinter):
    """
    The base class for code-printing subclasses.

    Subclasses are expected to provide the language-specific hooks that
    raise ``NotImplementedError`` here (``_get_statement``, ``_get_comment``,
    ``_declare_number_const``, ``_format_code``, ``_get_loop_opening_ending``
    and ``_rate_index_position``).  This class also reads ``self.language``,
    ``self.known_functions`` and ``self._traverse_matrix_indices`` without
    defining them, so those presumably come from the subclass as well.
    """

    # Textual operators used by the boolean printers.  'xor' and
    # 'equivalent' are absent here, so ``_print_Xor``/``_print_Equivalent``
    # report "not supported" unless a subclass adds them.
    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }

    _default_settings = {
        'order': None,
        'full_prec': 'auto',
        'error_on_reserved': False,
        'reserved_word_suffix': '_',
        'human': True,
        'inline': False,
        'allow_unknown_functions': False,
    }

    def __init__(self, settings=None):
        """Initialise the printer and make sure ``reserved_words`` exists."""
        super(CodePrinter, self).__init__(settings=settings)
        # A subclass may already define reserved_words (e.g. as a class
        # attribute); only fall back to an empty set when it does not.
        if not hasattr(self, 'reserved_words'):
            self.reserved_words = set()

    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ----------
        expr : Expression
            The expression to be printed.

        assign_to : Symbol, MatrixSymbol, or string (optional)
            If provided, the printed code will set the expression to a
            variable with name ``assign_to``.

        Returns
        -------
        str or tuple
            With the ``human`` setting enabled, a single string in which
            comments for unsupported expressions and declarations of number
            constants precede the code.  Otherwise a tuple
            ``(number_symbols, not_supported, code)`` where
            ``number_symbols`` is a set of ``(symbol, printed_value)`` pairs
            and ``not_supported`` is the set of unsupported expressions.

        Raises
        ------
        TypeError
            If ``assign_to`` is neither a string, a ``Basic`` instance nor
            ``None``.
        """
        from sympy.matrices.expressions.matexpr import MatrixSymbol

        if isinstance(assign_to, string_types):
            if expr.is_Matrix:
                assign_to = MatrixSymbol(assign_to, *expr.shape)
            else:
                assign_to = Symbol(assign_to)
        elif not isinstance(assign_to, (Basic, type(None))):
            raise TypeError("{0} cannot assign to object of type {1}".format(
                type(self).__name__, type(assign_to)))

        if assign_to:
            expr = Assignment(assign_to, expr)
        else:
            # _sympify is not enough b/c it errors on iterables
            expr = sympify(expr)

        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()

        lines = self._print(expr).splitlines()

        # format the output
        if self._settings["human"]:
            frontlines = []
            if self._not_supported:
                frontlines.append(self._get_comment(
                    "Not supported in {0}:".format(self.language)))
                for unsupported in sorted(self._not_supported, key=str):
                    frontlines.append(
                        self._get_comment(type(unsupported).__name__))
            for name, value in sorted(self._number_symbols, key=str):
                frontlines.append(self._declare_number_const(name, value))
            lines = frontlines + lines
            lines = self._format_code(lines)
            result = "\n".join(lines)
        else:
            lines = self._format_code(lines)
            num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
            result = (num_syms, self._not_supported, "\n".join(lines))
        # Rebind (rather than clear) so the set handed out in ``result``
        # stays intact for the caller.
        self._not_supported = set()
        self._number_symbols = set()
        return result

    def _doprint_loops(self, expr, assign_to=None):
        """Print ``expr`` (containing Indexed objects) as explicit loops.

        Raises
        ------
        NotImplementedError
            If a factor of a summed term has its own internal contractions.
        AssignmentError
            If a term with summation indices is printed without
            ``assign_to``.
        ValueError
            If ``assign_to`` also appears in a summed term.
        """
        # Here we print an expression that contains Indexed objects, they
        # correspond to arrays in the generated code.  The low-level
        # implementation involves looping over array elements and possibly
        # storing results in temporary variables or accumulate it in the
        # assign_to object.

        if self._settings.get('contract', True):
            from sympy.tensor import get_contraction_structure
            # Setup loops over non-dummy indices  --  all terms need these
            indices = self._get_expression_indices(expr, assign_to)
            # Setup loops over dummy indices  --  each term needs separate
            # treatment
            dummies = get_contraction_structure(expr)
        else:
            indices = []
            dummies = {None: (expr,)}
        openloop, closeloop = self._get_loop_opening_ending(indices)

        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)

        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            if assign_to is not None:
                text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)

        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    indices)

                for term in dummies[d]:
                    if term in dummies and not (
                            [list(f.keys()) for f in dummies[term]]
                            == [[None] for f in dummies[term]]):
                        # If one factor in the term has it's own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:

                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        if assign_to is None:
                            raise AssignmentError(
                                "need assignment variable for loops")
                        if term.has(assign_to):
                            raise ValueError(
                                "FIXME: lhs present in rhs, "
                                "this is undefined in CodePrinter")

                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)

        return "\n".join(lines)

    def _get_expression_indices(self, expr, assign_to):
        """Return the non-dummy indices of ``expr`` in loop order.

        Raises ``ValueError`` when the indices of ``assign_to`` differ from
        those of ``expr`` (after allowing a scalar right-hand side to be
        broadcast over the left-hand side's indices).
        """
        from sympy.tensor import get_indices
        rinds, junk = get_indices(expr)
        linds, junk = get_indices(assign_to)

        # support broadcast of scalar
        if linds and not rinds:
            rinds = linds
        if rinds != linds:
            raise ValueError("lhs indices must match non-dummy"
                             " rhs indices in %s" % expr)

        return self._sort_optimized(rinds, assign_to)

    def _sort_optimized(self, indices, expr):
        """Return ``indices`` sorted by ascending score.

        Each index starts at zero and gains ``_rate_index_position(p)`` for
        every position ``p`` at which it occurs in an ``Indexed`` atom of
        ``expr``.  An empty ``indices`` yields an empty list.
        """
        from sympy.tensor.indexed import Indexed

        if not indices:
            return []

        # determine optimized loop order by giving a score to each index
        # the index with the highest score are put in the innermost loop.
        score_table = {}
        for i in indices:
            score_table[i] = 0

        arrays = expr.atoms(Indexed)
        for arr in arrays:
            for p, ind in enumerate(arr.indices):
                try:
                    score_table[ind] += self._rate_index_position(p)
                except KeyError:
                    # index of the array that is not one of ``indices``
                    pass

        return sorted(indices, key=lambda x: score_table[x])

    def _rate_index_position(self, p):
        """function to calculate score based on position among indices

        This method is used to sort loops in an optimized order, see
        CodePrinter._sort_optimized()
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_statement(self, codestring):
        """Formats a codestring with the proper line ending."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_comment(self, text):
        """Formats a text string as a comment."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _declare_number_const(self, name, value):
        """Declare a numeric constant at the top of a function"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _format_code(self, lines):
        """Take in a list of lines of code, and format them accordingly.

        This may include indenting, wrapping long lines, etc..."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_loop_opening_ending(self, indices):
        """Returns a tuple (open_lines, close_lines) containing lists
        of codelines"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _print_Dummy(self, expr):
        """Print a Dummy as ``_<name>`` or ``<name>_<dummy_index>``."""
        if expr.name.startswith('Dummy_'):
            return '_' + expr.name
        else:
            return '%s_%d' % (expr.name, expr.dummy_index)

    def _print_CodeBlock(self, expr):
        """Print every argument of the block, one per line."""
        return '\n'.join([self._print(i) for i in expr.args])

    def _print_String(self, string):
        return str(string)

    def _print_QuotedString(self, arg):
        return '"%s"' % arg.text

    def _print_Comment(self, string):
        return self._get_comment(str(string))

    def _print_Assignment(self, expr):
        """Print an assignment, expanding multi-line cases.

        Piecewise right-hand sides become a Piecewise of assignments,
        MatrixSymbol left-hand sides are assigned element by element, and
        when the ``contract`` setting is on, expressions involving
        ``IndexedBase`` are delegated to ``_doprint_loops``.
        """
        from sympy.functions.elementary.piecewise import Piecewise
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.tensor.indexed import IndexedBase
        lhs = expr.lhs
        rhs = expr.rhs
        # We special case assignments that take multiple lines
        if isinstance(expr.rhs, Piecewise):
            # Here we modify Piecewise so each expression is now
            # an Assignment, and then continue on the print.
            expressions = []
            conditions = []
            for (e, c) in rhs.args:
                expressions.append(Assignment(lhs, e))
                conditions.append(c)
            temp = Piecewise(*zip(expressions, conditions))
            return self._print(temp)
        elif isinstance(lhs, MatrixSymbol):
            # Here we form an Assignment for each element in the array,
            # printing each one.
            lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                temp = Assignment(lhs[i, j], rhs[i, j])
                code0 = self._print(temp)
                lines.append(code0)
            return "\n".join(lines)
        elif self._settings.get("contract", False) and (
                lhs.has(IndexedBase) or rhs.has(IndexedBase)):
            # Here we check if there is looping to be done, and if so
            # print the required loops.
            return self._doprint_loops(rhs, lhs)
        else:
            lhs_code = self._print(lhs)
            rhs_code = self._print(rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))

    def _print_AugmentedAssignment(self, expr):
        """Print ``lhs op rhs`` as a statement."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        # NOTE: the already-printed operands are passed through ``_print``
        # once more together with the operator; kept as in the original.
        return self._get_statement("{0} {1} {2}".format(
            *map(lambda arg: self._print(arg),
                 [lhs_code, expr.op, rhs_code])))

    def _print_FunctionCall(self, expr):
        """Print ``name(arg1, arg2, ...)``."""
        return '%s(%s)' % (
            expr.name,
            ', '.join(self._print(arg) for arg in expr.function_args))

    def _print_Variable(self, expr):
        return self._print(expr.symbol)

    def _print_Statement(self, expr):
        arg, = expr.args
        return self._get_statement(self._print(arg))

    def _print_Symbol(self, expr):
        """Print a symbol name, handling clashes with reserved words.

        A reserved name gets the ``reserved_word_suffix`` setting appended,
        or raises ``ValueError`` when ``error_on_reserved`` is set.
        """
        name = super(CodePrinter, self)._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                msg = ('This expression includes the symbol "{}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        else:
            return name

    def _print_Function(self, expr):
        """Print a function application.

        Resolution order: an entry in ``self.known_functions`` keyed by the
        function's class name, then an ``_imp_`` Lambda that is inlined,
        then (with ``allow_unknown_functions``) a plain call, and finally
        the "not supported" fallback.
        """
        if expr.func.__name__ in self.known_functions:
            cond_func = self.known_functions[expr.func.__name__]
            func = None
            if isinstance(cond_func, string_types):
                func = cond_func
            else:
                # NOTE: the loop has no ``else`` clause, so when no
                # condition matches ``func`` stays bound to the last
                # entry's function; with an empty table it stays None and
                # this method returns None.  Preserved as-is.
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
            if func is not None:
                try:
                    return func(*[self.parenthesize(item, 0)
                                  for item in expr.args])
                except TypeError:
                    # ``func`` is not callable (e.g. a plain name string)
                    return "%s(%s)" % (func,
                                       self.stringify(expr.args, ", "))
        elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        elif expr.is_Function and self._settings.get(
                'allow_unknown_functions', False):
            return '%s(%s)' % (self._print(expr.func),
                               ', '.join(map(self._print, expr.args)))
        else:
            return self._print_not_supported(expr)

    _print_Expr = _print_Function

    def _print_NumberSymbol(self, expr):
        """Print a number symbol inline or register it for declaration."""
        if self._settings.get("inline", False):
            return self._print(Float(expr.evalf(self._settings["precision"])))
        else:
            # A Number symbol that is not implemented here or with
            # _printmethod is registered and evaluated
            self._number_symbols.add(
                (expr, Float(expr.evalf(self._settings["precision"]))))
            return str(expr)

    def _print_Catalan(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_EulerGamma(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_GoldenRatio(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_TribonacciConstant(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Exp1(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Pi(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_And(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['and']).join(
            self.parenthesize(a, PREC)
            for a in sorted(expr.args, key=default_sort_key))

    def _print_Or(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['or']).join(
            self.parenthesize(a, PREC)
            for a in sorted(expr.args, key=default_sort_key))

    def _print_Xor(self, expr):
        if self._operators.get('xor') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['xor']).join(
            self.parenthesize(a, PREC) for a in expr.args)

    def _print_Equivalent(self, expr):
        if self._operators.get('equivalent') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['equivalent']).join(
            self.parenthesize(a, PREC) for a in expr.args)

    def _print_Not(self, expr):
        PREC = precedence(expr)
        return self._operators['not'] + self.parenthesize(expr.args[0], PREC)

    def _print_Mul(self, expr):
        """Print a product as ``numerator/denominator`` with ``*`` factors."""
        prec = precedence(expr)

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Will collect all pow with more than one base element and exp = -1
        pow_paren = []

        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        for item in args:
            if (item.is_commutative and item.is_Pow
                    and item.exp.is_Rational and item.exp.is_negative):
                if item.exp != -1:
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    # To avoid situations like #14160
                    if (len(item.args[0].args) != 1
                            and isinstance(item.base, Mul)):
                        pow_paren.append(item)
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]

        # To parenthesize Pow with exp = -1 and having more than one Symbol
        for item in pow_paren:
            if item.base in b:
                b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]

        if not b:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)

    def _print_not_supported(self, expr):
        """Record ``expr`` as unsupported and fall back to ``emptyPrinter``."""
        self._not_supported.add(expr)
        return self.emptyPrinter(expr)

    # The following can not be simply translated into C or Fortran
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_AccumulationBounds = _print_not_supported
    _print_Limit = _print_not_supported
    _print_Matrix = _print_not_supported
    _print_ImmutableMatrix = _print_not_supported
    _print_ImmutableDenseMatrix = _print_not_supported
    _print_MutableDenseMatrix = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Normal = _print_not_supported
    _print_Order = _print_not_supported
    _print_PDF = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_Sample = _print_not_supported
    _print_SparseMatrix = _print_not_supported
    _print_MutableSparseMatrix = _print_not_supported
    _print_ImmutableSparseMatrix = _print_not_supported
    _print_Uniform = _print_not_supported
    _print_Unit = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported