# Source code for sympy.matrices.expressions.trace
from __future__ import print_function, division
from sympy import Basic, Expr, sympify, S
from sympy.matrices.matrices import MatrixBase
from .matexpr import ShapeError
class Trace(Expr):
    """Matrix Trace

    Represents the trace of a matrix expression.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Trace
    >>> A = MatrixSymbol('A', 3, 3)
    >>> Trace(A)
    Trace(A)
    """
    is_Trace = True
    is_commutative = True

    def __new__(cls, mat):
        """Create an unevaluated trace of ``mat``.

        ``mat`` is sympified first.

        Raises
        ======

        TypeError
            If ``mat`` is not a matrix.
        ShapeError
            If ``mat`` is not square.
        """
        mat = sympify(mat)
        if not mat.is_Matrix:
            raise TypeError("input to Trace, %s, is not a matrix" % str(mat))
        if not mat.is_square:
            raise ShapeError("Trace of a non-square matrix")
        return Basic.__new__(cls, mat)

    def _eval_transpose(self):
        # A trace is a (commutative) scalar, so transposing it is a no-op.
        return self

    def _eval_derivative(self, v):
        # Delegate to the generic matrix-derivative machinery; imported
        # locally, as in the original module.
        from sympy.matrices.expressions.matexpr import _matrix_derivative
        return _matrix_derivative(self, v)

    def _eval_derivative_matrix_lines(self, x):
        # Differentiate the argument into matrix lines, then fold each
        # line's ``first`` / ``second`` factors into ``higher`` and reset
        # them to one, so the lines describe the derivative of the trace.
        r = self.args[0]._eval_derivative_matrix_lines(x)
        for lr in r:
            if lr.higher == 1:
                lr.higher *= lr.first * lr.second.T
            else:
                # This is not a matrix line:
                lr.higher *= Trace(lr.first * lr.second.T)
            lr.first = S.One
            lr.second = S.One
        return r

    @property
    def arg(self):
        """The matrix whose trace this object represents."""
        return self.args[0]

    def doit(self, **kwargs):
        """Evaluate the trace as far as possible.

        With ``deep=True`` (the default) the argument is evaluated first.
        Its ``_eval_trace`` hook is then used. If the argument has no such
        hook, or the hook raises ``NotImplementedError``, an unevaluated
        ``Trace`` of the evaluated argument is returned.

        With ``deep=False`` only an explicit ``MatrixBase`` argument is
        reduced. Any other argument is wrapped in a new ``Trace`` unchanged.
        """
        if kwargs.get('deep', True):
            arg = self.arg.doit(**kwargs)
            try:
                return arg._eval_trace()
            except (AttributeError, NotImplementedError):
                # NOTE(review): an AttributeError raised *inside*
                # _eval_trace is swallowed here too, which can mask bugs
                # in that hook.
                return Trace(arg)
        else:
            # _eval_trace would go too deep here
            # NOTE(review): trace() calls doit() with the default
            # deep=True, so the explicit matrix is still doit()-ed
            # (presumably entry by entry) -- confirm this is intended.
            if isinstance(self.arg, MatrixBase):
                return trace(self.arg)
            else:
                return Trace(self.arg)

    def _eval_rewrite_as_Sum(self, expr, **kwargs):
        """Rewrite the trace as a ``Sum`` over the diagonal entries.

        The sum runs over a fresh dummy index from 0 to ``rows - 1`` and is
        passed through ``doit()`` before being returned.
        """
        from sympy import Sum, Dummy
        i = Dummy('i')
        return Sum(self.arg[i, i], (i, 0, self.arg.rows - 1)).doit()
def trace(expr):
    """Return the trace of a matrix, i.e. the sum of its diagonal entries.

    This builds ``Trace(expr)`` and immediately evaluates it with
    ``doit()``. As the examples show, an explicit matrix reduces to a
    number, while a symbolic matrix expression stays in terms of ``Trace``.

    Examples
    ========

    >>> from sympy import trace, Symbol, MatrixSymbol, eye
    >>> n = Symbol('n')
    >>> X = MatrixSymbol('X', n, n)  # A square matrix
    >>> trace(2*X)
    2*Trace(X)
    >>> trace(eye(3))
    3
    """
    unevaluated = Trace(expr)
    return unevaluated.doit()