Source code for sympy.printing.pycode

"""
Python code printers

This module contains Python code printers for plain Python as well as NumPy- and SciPy-enabled code.
"""


from collections import defaultdict
from itertools import chain
from sympy.core import S, Number, Symbol, Mul, Add
from .precedence import precedence
from .codeprinter import CodePrinter

# Words that are reserved in both Python 2 and Python 3.
# NOTE: 'None' is actually not in Python 2's keyword.kwlist.
_kw_py2and3 = set(
    'and as assert break class continue def del elif '
    'else except finally for from global if import in '
    'is lambda not or pass raise return try while '
    'with yield None'.split()
)
# Words that are reserved in Python 2 only.
_kw_only_py2 = set(['exec', 'print'])
# Words that are reserved in Python 3 only.
_kw_only_py3 = set(['False', 'nonlocal', 'True'])

# SymPy class name -> Python builtin that needs no module prefix.
_known_functions = dict(Abs='abs')

# SymPy class name -> name of the matching function in the ``math`` module.
# Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp
# pow modf radians trunc fmod fsum gcd degrees fabs]
_known_functions_math = dict([
    ('acos', 'acos'),
    ('acosh', 'acosh'),
    ('asin', 'asin'),
    ('asinh', 'asinh'),
    ('atan', 'atan'),
    ('atan2', 'atan2'),
    ('atanh', 'atanh'),
    ('ceiling', 'ceil'),
    ('cos', 'cos'),
    ('cosh', 'cosh'),
    ('erf', 'erf'),
    ('erfc', 'erfc'),
    ('exp', 'exp'),
    ('expm1', 'expm1'),
    ('factorial', 'factorial'),
    ('floor', 'floor'),
    ('gamma', 'gamma'),
    ('hypot', 'hypot'),
    ('loggamma', 'lgamma'),
    ('log', 'log'),
    ('ln', 'log'),
    ('log10', 'log10'),
    ('log1p', 'log1p'),
    ('log2', 'log2'),
    ('sin', 'sin'),
    ('sinh', 'sinh'),
    ('Sqrt', 'sqrt'),
    ('tan', 'tan'),
    ('tanh', 'tanh'),
])

# SymPy class name -> name of the matching constant in the ``math`` module.
# Only in python >= 3.5 (and therefore left out):
#     'Infinity': 'inf',
#     'NaN': 'nan'
_known_constants_math = dict([
    ('Exp1', 'e'),
    ('Pi', 'pi'),
    ('E', 'e'),
])

def _print_known_func(self, expr):
    """Print ``expr`` as a call to the function registered for its class.

    The target name is looked up in ``self.known_functions`` under the class
    name of ``expr`` (a missing entry raises ``KeyError``), passed through
    ``self._module_format`` and applied to the comma-separated printed
    arguments of ``expr``.
    """
    target = self.known_functions[expr.__class__.__name__]
    # Format the callee first, then the arguments, so that any bookkeeping
    # done by the printer happens in the same order as the output reads.
    callee = self._module_format(target)
    printed_args = ', '.join(self._print(item) for item in expr.args)
    return '%s(%s)' % (callee, printed_args)


def _print_known_const(self, expr):
    """Print ``expr`` as the constant registered for its class.

    The name stored in ``self.known_constants`` under the class name of
    ``expr`` (a missing entry raises ``KeyError``) is returned after being
    passed through ``self._module_format``.
    """
    class_name = expr.__class__.__name__
    return self._module_format(self.known_constants[class_name])


class AbstractPythonCodePrinter(CodePrinter):
    """Common base of the printers in this module that emit Python code.

    The class keeps two lookup tables, ``known_functions`` and
    ``known_constants``, mapping SymPy class names to (possibly
    module-qualified) Python names.  Every qualified name that goes through
    ``_module_format`` is recorded in ``module_imports``.
    """
    printmethod = "_pythoncode"
    language = "Python"
    standard = "python3"
    reserved_words = _kw_py2and3.union(_kw_only_py3)
    # NOTE(review): the original comment said this is "initialized to a set
    # in __init__", but the ``__init__`` below never assigns it; used modules
    # are tracked in ``module_imports`` instead.
    modules = None
    tab = '    '
    # Class-level defaults for ``known_functions``: builtins are used
    # unqualified, everything else comes from ``math``.
    _kf = dict(chain(
        _known_functions.items(),
        [(k, 'math.' + v) for k, v in _known_functions_math.items()]
    ))
    # Class-level defaults for ``known_constants``.
    _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
    # 'not' carries a trailing space because ``_print_Not`` glues the
    # operator directly onto its operand; without the space an operand that
    # is printed without parentheses would fuse with the keyword ("notx").
    _operators = {'and': 'and', 'or': 'or', 'not': 'not '}
    _default_settings = dict(
        CodePrinter._default_settings,
        user_functions={},
        precision=17,
        inline=True,
        fully_qualified_modules=True,
        contract=False
    )

    def __init__(self, settings=None):
        """Create the printer.

        ``settings`` is an optional dict.  Its ``user_functions`` and
        ``user_constants`` entries (if present) are merged over the
        class-level tables ``_kf`` and ``_kc``.
        """
        super(AbstractPythonCodePrinter, self).__init__(settings)
        # module name -> set of names used from that module
        self.module_imports = defaultdict(set)
        self.known_functions = dict(self._kf, **(settings or {}).get(
            'user_functions', {}))
        self.known_constants = dict(self._kc, **(settings or {}).get(
            'user_constants', {}))

    def _declare_number_const(self, name, value):
        """Return an assignment statement binding ``name`` to ``value``."""
        return "%s = %s" % (name, value)

    def _module_format(self, fqn, register=True):
        """Return the name to print for the fully qualified name ``fqn``.

        When ``register`` is true and ``fqn`` contains a dot, the last
        component is recorded in ``module_imports`` under the module part.
        With the ``fully_qualified_modules`` setting enabled ``fqn`` is
        returned unchanged; otherwise only the last dotted component of the
        text preceding any ``(`` or ``[`` is returned.
        """
        parts = fqn.split('.')
        if register and len(parts) > 1:
            self.module_imports['.'.join(parts[:-1])].add(parts[-1])

        if self._settings['fully_qualified_modules']:
            return fqn
        else:
            return fqn.split('(')[0].split('[')[0].split('.')[-1]

    def _format_code(self, lines):
        """Return ``lines`` unchanged."""
        return lines

    def _get_statement(self, codestring):
        """Return ``codestring`` as a statement (no terminator is added)."""
        return "{}".format(codestring)

    def _get_comment(self, text):
        """Return ``text`` formatted as a trailing ``#`` comment."""
        return "  # {0}".format(text)

    def _expand_fold_binary_op(self, op, args):
        """
        This method expands a fold on binary operations.

        ``functools.reduce`` is an example of a folded operation.

        For example, the expression

        `A + B + C + D`

        is folded into

        `((A + B) + C) + D`

        ``op`` is the (qualified) name of the binary function and ``args``
        the non-empty sequence of operands.
        """
        if len(args) == 1:
            return self._print(args[0])
        else:
            return "%s(%s, %s)" % (
                self._module_format(op),
                self._expand_fold_binary_op(op, args[:-1]),
                self._print(args[-1]),
            )

    def _expand_reduce_binary_op(self, op, args):
        """
        This method expands a reduction on binary operations.

        Notice: this is NOT the same as ``functools.reduce``.

        For example, the expression

        `A + B + C + D`

        is reduced into:

        `(A + B) + (C + D)`

        ``op`` is the (qualified) name of the binary function and ``args``
        the non-empty sequence of operands.
        """
        if len(args) == 1:
            return self._print(args[0])
        else:
            N = len(args)
            Nhalf = N // 2
            # ``op`` has to be forwarded to both halves; the recursive calls
            # previously omitted it and therefore raised TypeError.
            return "%s(%s, %s)" % (
                self._module_format(op),
                self._expand_reduce_binary_op(op, args[:Nhalf]),
                self._expand_reduce_binary_op(op, args[Nhalf:]),
            )

    def _get_einsum_string(self, subranks, contraction_indices):
        """Build the index string for an einsum-style contraction.

        ``subranks`` holds the rank of every operand and
        ``contraction_indices`` groups the (globally counted) index
        positions that are contracted together.

        Returns a triple ``(contraction_string, letters_free, letters_dum)``:
        the comma-separated letters of every operand, the letters of the
        uncontracted positions, and the letters of the contracted ones.
        """
        # NOTE(review): ``_get_letter_generator_for_einsum`` is not defined
        # in this class as visible here -- presumably supplied elsewhere;
        # confirm before relying on this method.
        letters = self._get_letter_generator_for_einsum()
        contraction_string = ""
        counter = 0
        # Every contracted position is represented by the smallest position
        # of its group.
        d = {j: min(i) for i in contraction_indices for j in i}
        indices = []
        for rank_arg in subranks:
            lindices = []
            for i in range(rank_arg):
                if counter in d:
                    lindices.append(d[counter])
                else:
                    lindices.append(counter)
                counter += 1
            indices.append(lindices)
        mapping = {}
        letters_free = []
        letters_dum = []
        for i in indices:
            for j in i:
                if j not in mapping:
                    l = next(letters)
                    mapping[j] = l
                else:
                    l = mapping[j]
                contraction_string += l
                if j in d:
                    if l not in letters_dum:
                        letters_dum.append(l)
                else:
                    letters_free.append(l)
            contraction_string += ","
        # Drop the trailing comma.
        contraction_string = contraction_string[:-1]
        return contraction_string, letters_free, letters_dum

    def _print_NaN(self, expr):
        return "float('nan')"

    def _print_Infinity(self, expr):
        return "float('inf')"

    def _print_NegativeInfinity(self, expr):
        return "float('-inf')"

    def _print_ComplexInfinity(self, expr):
        # Complex infinity is printed the same way as NaN.
        return self._print_NaN(expr)

    def _print_Mod(self, expr):
        """Print ``Mod`` with the ``%`` operator."""
        PREC = precedence(expr)
        return ('{0} % {1}'.format(*map(lambda x: self.parenthesize(x, PREC), expr.args)))

    def _print_Piecewise(self, expr):
        """Print a ``Piecewise`` as a chain of conditional expressions.

        A final branch whose printed condition is ``True`` becomes the
        ``else`` value; otherwise the chain ends in ``else None``.
        """
        result = []
        i = 0
        for arg in expr.args:
            e = arg.expr
            c = arg.cond
            if i == 0:
                result.append('(')
            result.append('(')
            result.append(self._print(e))
            result.append(')')
            result.append(' if ')
            result.append(self._print(c))
            result.append(' else ')
            i += 1
        # Remove the dangling ' else ' of the last branch.
        result = result[:-1]
        if result[-1] == 'True':
            # Unconditional last branch: strip its ' if ' and 'True'.
            result = result[:-2]
            result.append(')')
        else:
            result.append(' else None)')
        return ''.join(result)

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        # Only the keys are used here, to decide whether the operator can be
        # written as a Python comparison.
        op = {
            '==' :'equal',
            '!=' :'not_equal',
            '<'  :'less',
            '<=' :'less_equal',
            '>'  :'greater',
            '>=' :'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
        return super(AbstractPythonCodePrinter, self)._print_Relational(expr)

    def _print_ITE(self, expr):
        """Print ``ITE`` by rewriting it as a ``Piecewise`` first."""
        from sympy.functions.elementary.piecewise import Piecewise
        return self._print(expr.rewrite(Piecewise))

    def _print_Sum(self, expr):
        """Print ``Sum`` as ``builtins.sum`` over a generator expression.

        Each limit ``(i, a, b)`` becomes ``for i in range(a, b+1)``.
        """
        loops = (
            'for {i} in range({a}, {b}+1)'.format(
                i=self._print(i),
                a=self._print(a),
                b=self._print(b))
            for i, a, b in expr.limits)
        return '(builtins.sum({function} {loops}))'.format(
            function=self._print(expr.function),
            loops=' '.join(loops))

    def _print_ImaginaryUnit(self, expr):
        return '1j'

    def _print_MatrixBase(self, expr):
        """Print a matrix as ``<name>(<nested list>)``.

        ``<name>`` is the entry of ``known_functions`` for the class name of
        ``expr``, falling back to the class name itself.
        """
        name = expr.__class__.__name__
        func = self.known_functions.get(name, name)
        return "%s(%s)" % (func, self._print(expr.tolist()))

    # All concrete matrix classes are routed through ``_print_MatrixBase``.
    # The lambda defers the lookup so that an override of
    # ``_print_MatrixBase`` in a subclass is picked up.
    _print_SparseMatrix = \
        _print_MutableSparseMatrix = \
        _print_ImmutableSparseMatrix = \
        _print_Matrix = \
        _print_DenseMatrix = \
        _print_MutableDenseMatrix = \
        _print_ImmutableMatrix = \
        _print_ImmutableDenseMatrix = \
        lambda self, expr: self._print_MatrixBase(expr)

    def _indent_codestring(self, codestring):
        """Prefix every line of ``codestring`` with ``self.tab``."""
        return '\n'.join([self.tab + line for line in codestring.split('\n')])

    def _print_FunctionDefinition(self, fd):
        """Print ``fd`` as a ``def`` statement with an indented body."""
        body = '\n'.join(map(lambda arg: self._print(arg), fd.body))
        return "def {name}({parameters}):\n{body}".format(
            name=self._print(fd.name),
            parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
            body=self._indent_codestring(body)
        )

    def _print_While(self, whl):
        """Print ``whl`` as a ``while`` loop with an indented body."""
        body = '\n'.join(map(lambda arg: self._print(arg), whl.body))
        return "while {cond}:\n{body}".format(
            cond=self._print(whl.condition),
            body=self._indent_codestring(body)
        )

    def _print_Declaration(self, decl):
        """Print a declaration as a plain assignment of the variable's value."""
        return '%s = %s' % (
            self._print(decl.variable.symbol),
            self._print(decl.variable.value)
        )

    def _print_Return(self, ret):
        """Print ``ret`` (which must have exactly one argument) as ``return``."""
        arg, = ret.args
        return 'return %s' % self._print(arg)

    def _print_Print(self, prnt):
        """Print ``prnt`` as a call to ``print``.

        A format string, when given, is applied with ``%`` to the tuple of
        printed arguments; a target stream is passed as ``file=``.
        """
        print_args = ', '.join(map(lambda arg: self._print(arg), prnt.print_args))
        if prnt.format_string != None: # Must be '!= None', cannot be 'is not None'
            print_args = '{0} % ({1})'.format(
                self._print(prnt.format_string), print_args)
        if prnt.file != None: # Must be '!= None', cannot be 'is not None'
            print_args += ', file=%s' % self._print(prnt.file)
        return 'print(%s)' % print_args

    def _print_Stream(self, strm):
        """Print a stream; ``stdout``/``stderr`` map to their ``sys`` objects."""
        if str(strm.name) == 'stdout':
            return self._module_format('sys.stdout')
        elif str(strm.name) == 'stderr':
            return self._module_format('sys.stderr')
        else:
            return self._print(strm.name)

    def _print_NoneToken(self, arg):
        return 'None'


class PythonCodePrinter(AbstractPythonCodePrinter):
    """Concrete printer built on ``AbstractPythonCodePrinter``.

    Adds printing of ``sign`` (through ``math.copysign``) and of logical
    negation.
    """

    def _print_sign(self, e):
        """Print ``sign(x)`` as ``0.0`` for a zero argument, else ``copysign(1, x)``."""
        copysign = self._module_format('math.copysign')
        operand = self._print(e.args[0])
        template = '(0.0 if {e} == 0 else {f}(1, {e}))'
        return template.format(f=copysign, e=operand)

    def _print_Not(self, expr):
        """Print a negation: the 'not' operator followed by its operand.

        The operand is parenthesized according to the precedence of ``expr``.
        """
        keyword = self._operators['not']
        level = precedence(expr)
        return keyword + self.parenthesize(expr.args[0], level)


# Attach a ``_print_<name>`` method to PythonCodePrinter for every known
# function ...
for k in PythonCodePrinter._kf:
    setattr(PythonCodePrinter, '_print_' + k, _print_known_func)

# ... and for every known ``math`` constant.
for k in _known_constants_math:
    setattr(PythonCodePrinter, '_print_' + k, _print_known_const)


def pycode(expr, **settings):
    """ Converts an expr to a string of Python code

    Parameters
    ==========

    expr : Expr
        A SymPy expression.
    fully_qualified_modules : bool
        Whether or not to write out full module names of functions
        (``math.sin`` vs. ``sin``). default: ``True``.

    Returns
    =======

    The result of ``PythonCodePrinter(settings).doprint(expr)``.

    Examples
    ========

    >>> from sympy import tan, Symbol
    >>> from sympy.printing.pycode import pycode
    >>> pycode(tan(Symbol('x')) + 1)
    'math.tan(x) + 1'

    """
    printer = PythonCodePrinter(settings)
    return printer.doprint(expr)
# Entries of the ``math`` table that are left out of the mpmath table.
_not_in_mpmath = ['log1p', 'log2']
_in_mpmath = [
    (k, v) for k, v in _known_functions_math.items()
    if k not in _not_in_mpmath
]
# SymPy class name -> name of the matching function in ``mpmath``.
_known_functions_mpmath = dict(_in_mpmath, sign='sign')
# SymPy class name -> name of the matching constant in ``mpmath``.
_known_constants_mpmath = {'Pi': 'pi'}
class MpmathPrinter(PythonCodePrinter):
    """
    Lambda printer for mpmath which maintains precision for floats
    """
    printmethod = "_mpmathcode"

    _kf = dict(chain(
        _known_functions.items(),
        [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
    ))
    # Constants listed in ``_known_constants_mpmath`` must come from mpmath:
    # without this override they were looked up in the inherited table and
    # printed as e.g. ``math.pi``.  The inherited entries are kept so that
    # constants without an mpmath mapping still resolve.
    _kc = dict(
        PythonCodePrinter._kc,
        **{k: 'mpmath.' + v for k, v in _known_constants_mpmath.items()}
    )

    def _print_Float(self, e):
        """Print a Float as ``mpmath.mpf`` built from its ``_mpf_`` tuple."""
        # XXX: This does not handle setting mpmath.mp.dps. It is assumed that
        # the caller of the lambdified function will have set it to sufficient
        # precision to match the Floats in the expression.

        # Remove 'mpz' if gmpy is installed.
        args = str(tuple(map(int, e._mpf_)))
        return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args)

    def _print_Rational(self, e):
        """Print a Rational as a quotient of two ``mpmath.mpf`` values."""
        return '{0}({1})/{0}({2})'.format(
            self._module_format('mpmath.mpf'),
            e.p,
            e.q,
        )

    def _print_uppergamma(self, e):
        """Print ``uppergamma(s, x)`` as ``mpmath.gammainc(s, x, mpmath.inf)``."""
        return "{0}({1}, {2}, {3})".format(
            self._module_format('mpmath.gammainc'),
            self._print(e.args[0]),
            self._print(e.args[1]),
            self._module_format('mpmath.inf'))

    def _print_lowergamma(self, e):
        """Print ``lowergamma(s, x)`` as ``mpmath.gammainc(s, 0, x)``."""
        return "{0}({1}, 0, {2})".format(
            self._module_format('mpmath.gammainc'),
            self._print(e.args[0]),
            self._print(e.args[1]))

    def _print_log2(self, e):
        """Print ``log2(x)`` as ``mpmath.log(x)/mpmath.log(2)``."""
        return '{0}({1})/{0}(2)'.format(
            self._module_format('mpmath.log'), self._print(e.args[0]))

    def _print_log1p(self, e):
        """Print ``log1p(x)`` as ``mpmath.log(x+1)``."""
        return '{0}({1}+1)'.format(
            self._module_format('mpmath.log'), self._print(e.args[0]))
# Attach a ``_print_<name>`` method to MpmathPrinter for every known function
# and every known mpmath constant.
for k in MpmathPrinter._kf:
    setattr(MpmathPrinter, '_print_' + k, _print_known_func)

for k in _known_constants_mpmath:
    setattr(MpmathPrinter, '_print_' + k, _print_known_const)


# Entries of the ``math`` table that are left out of the NumPy table.
_not_in_numpy = ['erf', 'erfc', 'factorial', 'gamma', 'loggamma']
_in_numpy = [
    (k, v) for k, v in _known_functions_math.items()
    if k not in _not_in_numpy
]
# SymPy class name -> name of the matching function in ``numpy``.  The
# keyword entries override or extend what is taken over from ``math``.
_known_functions_numpy = dict(
    _in_numpy,
    acos='arccos',
    acosh='arccosh',
    asin='arcsin',
    asinh='arcsinh',
    atan='arctan',
    atan2='arctan2',
    atanh='arctanh',
    exp2='exp2',
    sign='sign',
)
class NumPyPrinter(PythonCodePrinter):
    """
    Numpy printer which handles vectorized piecewise functions,
    logical operators, etc.
    """
    printmethod = "_numpycode"

    _kf = dict(chain(
        PythonCodePrinter._kf.items(),
        [(k, 'numpy.' + v) for k, v in _known_functions_numpy.items()]
    ))
    _kc = {k: 'numpy.'+v for k, v in _known_constants_math.items()}

    def _print_seq(self, seq):
        "General sequence printer: converts to tuple"
        # Print tuples here instead of lists because numba supports
        #     tuples in nopython mode.
        delimiter = ', '
        return '({},)'.format(delimiter.join(self._print(item) for item in seq))

    def _print_MatMul(self, expr):
        "Matrix multiplication printer"
        # Split off the coefficient once instead of recomputing it per use.
        coeff_and_matrices = expr.as_coeff_matrices()
        coeff = coeff_and_matrices[0]
        if coeff is not S(1):
            # A non-trivial coefficient is moved to the end of the chain.
            expr_list = coeff_and_matrices[1] + [coeff]
            return '({0})'.format(').dot('.join(self._print(i) for i in expr_list))
        return '({0})'.format(').dot('.join(self._print(i) for i in expr.args))

    def _print_MatPow(self, expr):
        "Matrix power printer"
        return '{0}({1}, {2})'.format(self._module_format('numpy.linalg.matrix_power'),
            self._print(expr.args[0]), self._print(expr.args[1]))

    def _print_Inverse(self, expr):
        "Matrix inverse printer"
        return '{0}({1})'.format(self._module_format('numpy.linalg.inv'),
            self._print(expr.args[0]))

    def _print_DotProduct(self, expr):
        # DotProduct allows any shape order, but numpy.dot does matrix
        # multiplication, so we have to make sure it gets 1 x n by n x 1.
        arg1, arg2 = expr.args
        if arg1.shape[0] != 1:
            arg1 = arg1.T
        if arg2.shape[1] != 1:
            arg2 = arg2.T

        return "%s(%s, %s)" % (self._module_format('numpy.dot'),
                               self._print(arg1),
                               self._print(arg2))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        exprs = '[{0}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
        conds = '[{0}]'.format(','.join(self._print(arg.cond) for arg in expr.args))
        # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
        #     it will behave the same as passing the 'default' kwarg to select()
        #     *as long as* it is the last element in expr.args.
        # If this is not the case, it may be triggered prematurely.
        #
        # The default goes through _module_format like every other numpy
        # name, so that it is registered in ``module_imports`` and follows
        # the ``fully_qualified_modules`` setting instead of always being
        # spelled ``numpy.nan``.
        return '{0}({1}, {2}, default={3})'.format(
            self._module_format('numpy.select'), conds, exprs,
            self._module_format('numpy.nan'))

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        op = {
            '==' :'equal',
            '!=' :'not_equal',
            '<'  :'less',
            '<=' :'less_equal',
            '>'  :'greater',
            '>=' :'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '{op}({lhs}, {rhs})'.format(op=self._module_format('numpy.'+op[expr.rel_op]),
                                               lhs=lhs, rhs=rhs)
        return super(NumPyPrinter, self)._print_Relational(expr)

    def _print_And(self, expr):
        "Logical And printer"
        # We have to override LambdaPrinter because it uses Python 'and' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
        return '{0}.reduce(({1}))'.format(self._module_format('numpy.logical_and'), ','.join(self._print(i) for i in expr.args))

    def _print_Or(self, expr):
        "Logical Or printer"
        # We have to override LambdaPrinter because it uses Python 'or' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
        return '{0}.reduce(({1}))'.format(self._module_format('numpy.logical_or'), ','.join(self._print(i) for i in expr.args))

    def _print_Not(self, expr):
        "Logical Not printer"
        # We have to override LambdaPrinter because it uses Python 'not' keyword.
        # If LambdaPrinter didn't define it, we would still have to define our
        #     own because StrPrinter doesn't define it.
        return '{0}({1})'.format(self._module_format('numpy.logical_not'), ','.join(self._print(i) for i in expr.args))

    def _print_Min(self, expr):
        """Print ``Min`` as ``numpy.amin`` over a tuple of the arguments."""
        return '{0}(({1}))'.format(self._module_format('numpy.amin'), ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        """Print ``Max`` as ``numpy.amax`` over a tuple of the arguments."""
        return '{0}(({1}))'.format(self._module_format('numpy.amax'), ','.join(self._print(i) for i in expr.args))

    def _print_Pow(self, expr):
        """Print a power; an exponent equal to 0.5 becomes ``numpy.sqrt``."""
        if expr.exp == 0.5:
            return '{0}({1})'.format(self._module_format('numpy.sqrt'), self._print(expr.base))
        else:
            return super(NumPyPrinter, self)._print_Pow(expr)

    def _print_arg(self, expr):
        return "%s(%s)" % (self._module_format('numpy.angle'), self._print(expr.args[0]))

    def _print_im(self, expr):
        return "%s(%s)" % (self._module_format('numpy.imag'), self._print(expr.args[0]))

    def _print_Mod(self, expr):
        return "%s(%s)" % (self._module_format('numpy.mod'), ', '.join(
            map(lambda arg: self._print(arg), expr.args)))

    def _print_re(self, expr):
        return "%s(%s)" % (self._module_format('numpy.real'), self._print(expr.args[0]))

    def _print_sinc(self, expr):
        # The argument is divided by pi before being handed to numpy.sinc.
        return "%s(%s)" % (self._module_format('numpy.sinc'), self._print(expr.args[0]/S.Pi))

    def _print_MatrixBase(self, expr):
        """Print a matrix through its known function, else ``numpy.array``."""
        func = self.known_functions.get(expr.__class__.__name__, None)
        if func is None:
            func = self._module_format('numpy.array')
        return "%s(%s)" % (func, self._print(expr.tolist()))

    def _print_CodegenArrayTensorProduct(self, expr):
        """Print a tensor product as ``numpy.einsum``.

        Every operand is followed by the index pair ``[2*i, 2*i+1]``.
        """
        array_list = [j for i, arg in enumerate(expr.args) for j in
                (self._print(arg), "[%i, %i]" % (2*i, 2*i+1))]
        return "%s(%s)" % (self._module_format('numpy.einsum'), ", ".join(array_list))

    def _print_CodegenArrayContraction(self, expr):
        """Print a contraction of a tensor product as ``numpy.einsum``.

        Without contraction indices the base expression is printed as is.
        Bases other than ``CodegenArrayTensorProduct`` raise
        ``NotImplementedError``.
        """
        from sympy.codegen.array_utils import CodegenArrayTensorProduct
        base = expr.expr
        contraction_indices = expr.contraction_indices
        if not contraction_indices:
            return self._print(base)
        if isinstance(base, CodegenArrayTensorProduct):
            counter = 0
            # Every contracted position is represented by the smallest
            # position of its group.
            d = {j: min(i) for i in contraction_indices for j in i}
            indices = []
            for rank_arg in base.subranks:
                lindices = []
                for i in range(rank_arg):
                    if counter in d:
                        lindices.append(d[counter])
                    else:
                        lindices.append(counter)
                    counter += 1
                indices.append(lindices)
            elems = ["%s, %s" % (self._print(arg), ind) for arg, ind in zip(base.args, indices)]
            return "%s(%s)" % (
                self._module_format('numpy.einsum'),
                ", ".join(elems)
            )
        raise NotImplementedError()

    def _print_CodegenArrayDiagonal(self, expr):
        """Print a single two-index diagonal as ``numpy.diagonal``."""
        diagonal_indices = list(expr.diagonal_indices)
        if len(diagonal_indices) > 1:
            # TODO: this should be handled in sympy.codegen.array_utils,
            # possibly by creating the possibility of unfolding the
            # CodegenArrayDiagonal object into nested ones. Same reasoning for
            # the array contraction.
            raise NotImplementedError
        if len(diagonal_indices[0]) != 2:
            raise NotImplementedError
        return "%s(%s, 0, axis1=%s, axis2=%s)" % (
            self._module_format("numpy.diagonal"),
            self._print(expr.expr),
            diagonal_indices[0][0],
            diagonal_indices[0][1],
        )

    def _print_CodegenArrayPermuteDims(self, expr):
        """Print a permutation of dimensions as ``numpy.transpose``."""
        return "%s(%s, %s)" % (
            self._module_format("numpy.transpose"),
            self._print(expr.expr),
            self._print(expr.permutation.args[0]),
        )

    def _print_CodegenArrayElementwiseAdd(self, expr):
        """Print an elementwise sum as nested ``numpy.add`` calls."""
        return self._expand_fold_binary_op('numpy.add', expr.args)
# Attach a ``_print_<name>`` method to NumPyPrinter for every known function
# and constant.
for k in NumPyPrinter._kf:
    setattr(NumPyPrinter, '_print_%s' % k, _print_known_func)

for k in NumPyPrinter._kc:
    setattr(NumPyPrinter, '_print_%s' % k, _print_known_const)


# SymPy class name -> name of the matching function in ``scipy.special``.
_known_functions_scipy_special = {
    'erf': 'erf',
    'erfc': 'erfc',
    'besselj': 'jv',
    'bessely': 'yv',
    'besseli': 'iv',
    'besselk': 'kv',
    'factorial': 'factorial',
    'gamma': 'gamma',
    'loggamma': 'gammaln',
    'digamma': 'psi',
    'RisingFactorial': 'poch',
    'jacobi': 'eval_jacobi',
    'gegenbauer': 'eval_gegenbauer',
    'chebyshevt': 'eval_chebyt',
    'chebyshevu': 'eval_chebyu',
    'legendre': 'eval_legendre',
    'hermite': 'eval_hermite',
    'laguerre': 'eval_laguerre',
    'assoc_laguerre': 'eval_genlaguerre',
}
# SymPy class name -> name of the matching constant in ``scipy.constants``.
# There is deliberately no 'E' entry: ``scipy.constants.e`` is the
# elementary charge, not Euler's number.
_known_constants_scipy_constants = {
    'GoldenRatio': 'golden_ratio',
    'Pi': 'pi',
}


class SciPyPrinter(NumPyPrinter):
    """NumPy printer extended with ``scipy.special``, ``scipy.constants``
    and ``scipy.sparse`` names."""

    _kf = dict(chain(
        NumPyPrinter._kf.items(),
        [(k, 'scipy.special.' + v) for k, v in _known_functions_scipy_special.items()]
    ))
    # Extend (rather than replace) the inherited constants: the inherited
    # ``_print_<name>`` methods look their name up in ``known_constants`` and
    # would raise KeyError for any constant dropped from the table.
    _kc = dict(chain(
        NumPyPrinter._kc.items(),
        [(k, 'scipy.constants.' + v) for k, v in _known_constants_scipy_constants.items()]
    ))

    def _print_SparseMatrix(self, expr):
        """Print a sparse matrix as a ``scipy.sparse.coo_matrix`` call."""
        i, j, data = [], [], []
        for (r, c), v in expr._smat.items():
            i.append(r)
            j.append(c)
            data.append(v)

        # coo_matrix takes the triplet form as ONE argument:
        # ``coo_matrix((data, (i, j)), shape=...)``.
        return "{name}(({data}, ({i}, {j})), shape={shape})".format(
            name=self._module_format('scipy.sparse.coo_matrix'),
            data=data, i=i, j=j, shape=expr.shape
        )

    _print_ImmutableSparseMatrix = _print_SparseMatrix

    # SciPy's lpmv has a different order of arguments from assoc_legendre
    def _print_assoc_legendre(self, expr):
        return "{0}({2}, {1}, {3})".format(
            self._module_format('scipy.special.lpmv'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]))


# Attach a ``_print_<name>`` method to SciPyPrinter for every known function
# and constant.
for k in SciPyPrinter._kf:
    setattr(SciPyPrinter, '_print_%s' % k, _print_known_func)

for k in SciPyPrinter._kc:
    setattr(SciPyPrinter, '_print_%s' % k, _print_known_const)


class SymPyPrinter(PythonCodePrinter):
    """Python printer whose known functions are qualified with ``sympy.``."""

    _kf = {k: 'sympy.' + v for k, v in chain(
        _known_functions.items(),
        _known_functions_math.items()
    )}

    def _print_Function(self, expr):
        """Print a function call qualified with the module defining the function."""
        mod = expr.func.__module__ or ''
        return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__),
                           ', '.join(map(lambda arg: self._print(arg), expr.args)))