# Source code for sympy.matrices.expressions.matexpr

from __future__ import print_function, division

from functools import wraps, reduce
import collections

from sympy.core import S, Symbol, Tuple, Integer, Basic, Expr, Eq, Mul, Add
from sympy.core.decorators import call_highest_priority
from sympy.core.compatibility import range, SYMPY_INTS, default_sort_key, string_types
from sympy.core.sympify import SympifyError, _sympify
from sympy.functions import conjugate, adjoint
from sympy.functions.special.tensor_functions import KroneckerDelta
from sympy.matrices import ShapeError
from sympy.simplify import simplify
from sympy.utilities.misc import filldedent


def _sympifyit(arg, retval=None):
    """Decorator factory that sympifies the second operand of a binary method.

    This version of ``_sympifyit`` sympifies MutableMatrix objects.

    Parameters
    ==========

    arg : str
        Name of the operand being sympified.  It is accepted for signature
        compatibility but is not used by this implementation.
    retval : object, optional
        Value returned when the operand cannot be sympified.  The operator
        methods in this module pass ``NotImplemented`` so that Python falls
        back to the reflected operation.
    """
    def deco(func):
        @wraps(func)
        def __sympifyit_wrapper(a, b):
            try:
                b = _sympify(b)
            except SympifyError:
                return retval
            # ``func`` is called outside the ``try`` so that a SympifyError
            # raised while *computing* the result propagates, instead of
            # being silently converted into ``retval``.
            return func(a, b)

        return __sympifyit_wrapper

    return deco


class MatrixExpr(Expr):
    """Superclass for Matrix Expressions

    MatrixExprs represent abstract matrices, linear transformations represented
    within a particular basis.

    Examples
    ========

    >>> from sympy import MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 3)
    >>> y = MatrixSymbol('y', 3, 1)
    >>> x = (A.T*A).I * A * y

    See Also
    ========

    MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
    """

    # Should not be considered iterable by the
    # sympy.core.compatibility.iterable function. Subclasses that actually are
    # iterable (i.e., explicit matrices) should set this to True.
    _iterable = False

    _op_priority = 11.0

    is_Matrix = True
    is_MatrixExpr = True
    is_Identity = None
    is_Inverse = False
    is_Transpose = False
    is_ZeroMatrix = False
    is_MatAdd = False
    is_MatMul = False

    is_commutative = False
    is_number = False
    is_symbol = False

    def __new__(cls, *args, **kwargs):
        args = map(_sympify, args)
        return Basic.__new__(cls, *args, **kwargs)

    # The following is adapted from the core Expr object
    def __neg__(self):
        return MatMul(S.NegativeOne, self).doit()

    def __abs__(self):
        raise NotImplementedError

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return MatAdd(self, other, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__add__')
    def __radd__(self, other):
        return MatAdd(other, self, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return MatAdd(self, -other, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return MatAdd(other, -self, check=True).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return MatMul(self, other).doit()

    # Defer to the higher-priority operand's *matmul* handler, not its ``*``
    # handler: for types where ``*`` and ``@`` differ, ``__rmul__`` would
    # compute the wrong product.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmatmul__')
    def __matmul__(self, other):
        return MatMul(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__matmul__')
    def __rmatmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        if not self.is_square:
            raise ShapeError("Power of non-square matrix %s" % self)
        elif self.is_Identity:
            return self
        elif other is S.Zero:
            return Identity(self.rows)
        elif other is S.One:
            return self
        return MatPow(self, other).doit(deep=False)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        raise NotImplementedError("Matrix Power not defined")

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rdiv__')
    def __div__(self, other):
        return self * other**S.NegativeOne

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__div__')
    def __rdiv__(self, other):
        # Dividing a scalar by a matrix expression is not supported.
        raise NotImplementedError()

    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    @property
    def rows(self):
        return self.shape[0]

    @property
    def cols(self):
        return self.shape[1]

    @property
    def is_square(self):
        return self.rows == self.cols

    def _eval_conjugate(self):
        from sympy.matrices.expressions.adjoint import Adjoint
        from sympy.matrices.expressions.transpose import Transpose
        return Adjoint(Transpose(self))

    def as_real_imag(self):
        from sympy import I
        real = (S(1)/2) * (self + self._eval_conjugate())
        im = (self - self._eval_conjugate())/(2*I)
        return (real, im)

    def _eval_inverse(self):
        from sympy.matrices.expressions.inverse import Inverse
        return Inverse(self)

    def _eval_transpose(self):
        return Transpose(self)

    def _eval_power(self, exp):
        return MatPow(self, exp)

    def _eval_simplify(self, **kwargs):
        if self.is_Atom:
            return self
        else:
            return self.__class__(*[simplify(x, **kwargs) for x in self.args])

    def _eval_adjoint(self):
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(self)

    def _eval_derivative(self, x):
        return _matrix_derivative(self, x)

    def _eval_derivative_n_times(self, x, n):
        return Basic._eval_derivative_n_times(self, x, n)

    def _visit_eval_derivative_scalar(self, x):
        # `x` is a scalar:
        if x.has(self):
            return _matrix_derivative(x, self)
        else:
            return ZeroMatrix(*self.shape)

    def _visit_eval_derivative_array(self, x):
        if x.has(self):
            return _matrix_derivative(x, self)
        else:
            from sympy import Derivative
            return Derivative(x, self)

    def _accept_eval_derivative(self, s):
        return s._visit_eval_derivative_array(self)

    def _entry(self, i, j, **kwargs):
        raise NotImplementedError(
            "Indexing not implemented for %s" % self.__class__.__name__)

    def adjoint(self):
        return adjoint(self)

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product. """
        return S.One, self

    def conjugate(self):
        return conjugate(self)

    def transpose(self):
        from sympy.matrices.expressions.transpose import transpose
        return transpose(self)

    T = property(transpose, None, None, 'Matrix transposition.')

    def inverse(self):
        return self._eval_inverse()

    inv = inverse

    @property
    def I(self):
        return self.inverse()

    def valid_index(self, i, j):
        """Return a truthy value unless ``(i, j)`` is provably out of range.

        Comparisons that cannot be decided (symbolic indices) yield a
        Relational, which compares ``!= False`` as True, so such indices are
        accepted.
        """
        def is_valid(idx):
            return isinstance(idx, (int, Integer, Symbol, Expr))

        return (is_valid(i) and is_valid(j) and
                (self.rows is None or
                (0 <= i) != False and (i < self.rows) != False) and
                (0 <= j) != False and (j < self.cols) != False)

    def __getitem__(self, key):
        if not isinstance(key, tuple) and isinstance(key, slice):
            from sympy.matrices.expressions.slice import MatrixSlice
            return MatrixSlice(self, key, (0, None, 1))
        if isinstance(key, tuple) and len(key) == 2:
            i, j = key
            if isinstance(i, slice) or isinstance(j, slice):
                from sympy.matrices.expressions.slice import MatrixSlice
                return MatrixSlice(self, i, j)
            i, j = _sympify(i), _sympify(j)
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid indices (%s, %s)" % (i, j))
        elif isinstance(key, (SYMPY_INTS, Integer)):
            # row-wise decomposition of matrix
            rows, cols = self.shape
            # allow single indexing if number of columns is known
            if not isinstance(cols, Integer):
                raise IndexError(filldedent('''
                    Single indexing is only supported when the number
                    of columns is known.'''))
            key = _sympify(key)
            i = key // cols
            j = key % cols
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid index %s" % key)
        elif isinstance(key, (Symbol, Expr)):
            raise IndexError(filldedent('''
                Only integers may be used when addressing the matrix
                with a single index.'''))
        raise IndexError("Invalid index, wanted %s[i,j]" % self)

    def as_explicit(self):
        """
        Returns a dense Matrix with elements represented explicitly

        Returns an object of type ImmutableDenseMatrix.

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.as_explicit()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_mutable: returns mutable Matrix type

        """
        from sympy.matrices.immutable import ImmutableDenseMatrix
        return ImmutableDenseMatrix([[    self[i, j]
                            for j in range(self.cols)]
                            for i in range(self.rows)])

    def as_mutable(self):
        """
        Returns a dense, mutable matrix with elements represented explicitly

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.shape
        (3, 3)
        >>> I.as_mutable()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_explicit: returns ImmutableDenseMatrix
        """
        return self.as_explicit().as_mutable()

    def __array__(self):
        from numpy import empty
        a = empty(self.shape, dtype=object)
        for i in range(self.rows):
            for j in range(self.cols):
                a[i, j] = self[i, j]
        return a

    def equals(self, other):
        """
        Test elementwise equality between matrices, potentially of different
        types

        >>> from sympy import Identity, eye
        >>> Identity(3).equals(eye(3))
        True
        """
        return self.as_explicit().equals(other)

    def canonicalize(self):
        return self

    def as_coeff_mmul(self):
        return 1, MatMul(self)

    @staticmethod
    def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
        r"""
        Parse expression of matrices with explicitly summed indices into a
        matrix expression without indices, if possible.

        This transformation expressed in mathematical notation:

        `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`

        Optional parameter ``first_index``: specify which free index to use as
        the index starting the expression; it is used together with
        ``last_index``.  Optional parameter ``dimensions``: when given,
        ``KroneckerDelta`` factors become ``Identity(dimensions[0])``
        instead of ``1``.

        Examples
        ========

        >>> from sympy import MatrixSymbol, MatrixExpr, Sum, Symbol
        >>> from sympy.abc import i, j, k, l, N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B

        Transposition is detected:

        >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A.T*B

        Detect the trace:

        >>> expr = Sum(A[i, i], (i, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        Trace(A)

        More complicated expressions:

        >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B.T*A.T
        """
        from sympy import Sum, Mul, Add, MatMul, transpose, trace
        from sympy.strategies.traverse import bottom_up

        def remove_matelement(expr, i1, i2):
            # Replace matrix elements indexed by ``i1`` with the bare matrix
            # (``i1`` in row position) or its transpose (``i1`` in column
            # position).

            def repl_match(pos):
                def func(x):
                    if not isinstance(x, MatrixElement):
                        return False
                    if x.args[pos] != i1:
                        return False
                    if x.args[3-pos] == 0:
                        if x.args[0].shape[2-pos] == 1:
                            return True
                        else:
                            return False
                    return True
                return func

            expr = expr.replace(repl_match(1),
                lambda x: x.args[0])
            expr = expr.replace(repl_match(2),
                lambda x: transpose(x.args[0]))

            # Make sure that all Mul are transformed to MatMul and that they
            # are flattened:
            rule = bottom_up(lambda x: reduce(lambda a, b: a*b, x.args) if isinstance(x, (Mul, MatMul)) else x)
            return rule(expr)

        def recurse_expr(expr, index_ranges=None):
            # Returns a list of ``(factor, indices)`` pairs; ``indices`` is
            # None for factors that carry no free matrix indices.
            if index_ranges is None:
                index_ranges = {}
            if expr.is_Mul:
                nonmatargs = []
                pos_arg = []
                pos_ind = []
                dlinks = {}
                link_ind = []
                counter = 0
                args_ind = []
                for arg in expr.args:
                    retvals = recurse_expr(arg, index_ranges)
                    assert isinstance(retvals, list)
                    if isinstance(retvals, list):
                        for i in retvals:
                            args_ind.append(i)
                    else:
                        args_ind.append(retvals)
                for arg_symbol, arg_indices in args_ind:
                    if arg_indices is None:
                        nonmatargs.append(arg_symbol)
                        continue
                    if isinstance(arg_symbol, MatrixElement):
                        arg_symbol = arg_symbol.args[0]
                    pos_arg.append(arg_symbol)
                    pos_ind.append(arg_indices)
                    link_ind.append([None]*len(arg_indices))
                    # Link each index to the previous factor carrying the
                    # same (contracted) index.
                    for i, ind in enumerate(arg_indices):
                        if ind in dlinks:
                            other_i = dlinks[ind]
                            link_ind[counter][i] = other_i
                            link_ind[other_i[0]][other_i[1]] = (counter, i)
                        dlinks[ind] = (counter, i)
                    counter += 1
                counter2 = 0
                lines = {}
                while counter2 < len(link_ind):
                    # NOTE(review): if every index is contracted (a closed
                    # cycle, e.g. the trace of a product), no ``None`` is
                    # found and ``line_start_index`` is unbound below.
                    for i, e in enumerate(link_ind):
                        if None in e:
                            line_start_index = (i, e.index(None))
                            break
                    cur_ind_pos = line_start_index
                    cur_line = []
                    index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
                    while True:
                        d, r = cur_ind_pos
                        if pos_arg[d] != 1:
                            if r % 2 == 1:
                                cur_line.append(transpose(pos_arg[d]))
                            else:
                                cur_line.append(pos_arg[d])
                        next_ind_pos = link_ind[d][1-r]
                        counter2 += 1
                        # Mark as visited, there will be no `None` anymore:
                        link_ind[d] = (-1, -1)
                        if next_ind_pos is None:
                            index2 = pos_ind[d][1-r]
                            lines[(index1, index2)] = cur_line
                            break
                        cur_ind_pos = next_ind_pos
                lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
                return [(Mul.fromiter(nonmatargs), None)] + [
                    (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
                ]
            elif expr.is_Add:
                # NOTE(review): ``index_ranges`` is not forwarded to the
                # addends, so range checks and trace detection are skipped
                # here (Sum(A[i, i] + B[i, i], (i, 0, N-1)) yields A + B).
                # Forwarding it alone is not enough: scalar-only addends
                # (e.g. traces) would then be dropped by the loop below.
                res = [recurse_expr(i) for i in expr.args]
                d = collections.defaultdict(list)
                for res_addend in res:
                    scalar = 1
                    for elem, indices in res_addend:
                        if indices is None:
                            scalar = elem
                            continue
                        indices = tuple(sorted(indices, key=default_sort_key))
                        d[indices].append(scalar*remove_matelement(elem, *indices))
                        scalar = 1
                return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
            elif isinstance(expr, KroneckerDelta):
                i1, i2 = expr.args
                if dimensions is not None:
                    identity = Identity(dimensions[0])
                else:
                    identity = S.One
                return [(MatrixElement(identity, i1, i2), (i1, i2))]
            elif isinstance(expr, MatrixElement):
                matrix_symbol, i1, i2 = expr.args
                if i1 in index_ranges:
                    r1, r2 = index_ranges[i1]
                    if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                        raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                            (r1, r2), matrix_symbol.shape[0]))
                if i2 in index_ranges:
                    r1, r2 = index_ranges[i2]
                    if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                        raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                            (r1, r2), matrix_symbol.shape[1]))
                if (i1 == i2) and (i1 in index_ranges):
                    return [(trace(matrix_symbol), None)]
                return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
            elif isinstance(expr, Sum):
                return recurse_expr(
                    expr.args[0],
                    index_ranges={i[0]: i[1:] for i in expr.args[1:]}
                )
            else:
                return [(expr, None)]

        retvals = recurse_expr(expr)
        factors, indices = zip(*retvals)
        retexpr = Mul.fromiter(factors)
        if len(indices) == 0 or list(set(indices)) == [None]:
            return retexpr
        if first_index is None:
            for i in indices:
                if i is not None:
                    ind0 = i
                    break
            return remove_matelement(retexpr, *ind0)
        else:
            return remove_matelement(retexpr, first_index, last_index)

    def applyfunc(self, func):
        from .applyfunc import ElementwiseApplyFunction
        return ElementwiseApplyFunction(func, self)

    def _eval_Eq(self, other):
        if not isinstance(other, MatrixExpr):
            return False
        if self.shape != other.shape:
            return False
        if (self - other).is_ZeroMatrix:
            return True
        return Eq(self, other, evaluate=False)
def get_postprocessor(cls):
    """Return a constructor postprocessor for ``cls`` (``Mul`` or ``Add``).

    The returned callable rewrites a ``cls`` expression that contains matrix
    expression arguments into the corresponding ``MatMul``/``MatAdd``.
    """
    def _postprocessor(expr):
        # MatMul/MatAdd are imported at the bottom of the module to avoid
        # circular imports, so the mapping must be resolved at call time.
        matrix_cls = {Mul: MatMul, Add: MatAdd}[cls]
        mat_terms = [t for t in expr.args if isinstance(t, MatrixExpr)]
        scalar_terms = [t for t in expr.args if not isinstance(t, MatrixExpr)]

        if not mat_terms:
            return cls._from_args(scalar_terms)

        if scalar_terms and cls == Mul:
            for pos, term in enumerate(mat_terms):
                if term.is_MatrixExpr:
                    continue
                # Fold the scalar factors into the first explicit matrix
                # (doit will combine all explicit matrices into one, so the
                # choice of matrix is irrelevant).
                mat_terms[pos] = term.__mul__(cls._from_args(scalar_terms))
                scalar_terms = []
                break
        elif scalar_terms:
            # Maintain the ability to create Add(scalar, matrix) without
            # raising an exception, so algorithms may stand in
            # non-commutative symbols for matrix expressions.
            combined = matrix_cls(*mat_terms).doit(deep=False)
            return cls._from_args(scalar_terms + [combined])

        return matrix_cls(cls._from_args(scalar_terms), *mat_terms).doit(deep=False)

    return _postprocessor


Basic._constructor_postprocessor_mapping[MatrixExpr] = {
    "Mul": [get_postprocessor(Mul)],
    "Add": [get_postprocessor(Add)],
}


def _matrix_derivative(expr, x):
    """Differentiate ``expr`` with respect to the matrix ``x``.

    Returns a matrix expression when the derivative has at most two
    nontrivial dimensions, otherwise an unevaluated ``Derivative``.
    """
    from sympy import Derivative
    lines = expr._eval_derivative_matrix_lines(x)

    ranks = [line.rank() for line in lines]
    assert len(set(ranks)) == 1
    if ranks[0] > 2:
        return Derivative(expr, x)
    return Add.fromiter([line.matrix_form() for line in lines])


class MatrixElement(Expr):
    """A single scalar entry ``parent[i, j]`` of a matrix expression."""
    parent = property(lambda self: self.args[0])
    i = property(lambda self: self.args[1])
    j = property(lambda self: self.args[2])
    _diff_wrt = True
    is_symbol = True
    is_commutative = True

    def __new__(cls, name, n, m):
        n = _sympify(n)
        m = _sympify(m)
        from sympy import MatrixBase
        # Explicit matrices indexed by concrete integers are looked up
        # directly rather than wrapped.
        if isinstance(name, MatrixBase) and n.is_Integer and m.is_Integer:
            return name[n, m]
        if isinstance(name, string_types):
            name = Symbol(name)
        return Expr.__new__(cls, _sympify(name), n, m)

    def doit(self, **kwargs):
        if kwargs.get('deep', True):
            parent, row, col = [arg.doit(**kwargs) for arg in self.args]
        else:
            parent, row, col = self.args
        return parent[row, col]

    @property
    def indices(self):
        return self.args[1:]

    def _eval_derivative(self, v):
        from sympy import Sum, symbols, Dummy

        parent = self.args[0]
        row, col = self.args[1], self.args[2]

        if not isinstance(v, MatrixElement):
            from sympy import MatrixBase
            if isinstance(parent, MatrixBase):
                return parent.diff(v)[row, col]
            return S.Zero

        if parent == v.args[0]:
            return KroneckerDelta(row, v.args[1])*KroneckerDelta(col, v.args[2])

        if isinstance(parent, Inverse):
            # d(Y^-1)_{ij} = -sum_{k,l} (Y^-1)_{ik} dY_{kl} (Y^-1)_{lj}
            z1, z2 = symbols("z1, z2", cls=Dummy)
            inner = parent.args[0]
            n_rows, n_cols = inner.shape
            return -Sum(parent[row, z1]*inner[z1, z2].diff(v)*parent[z2, col],
                        (z1, 0, n_rows-1), (z2, 0, n_cols-1))

        if self.has(v.args[0]):
            return None

        return S.Zero
class MatrixSymbol(MatrixExpr):
    """Symbolic representation of a Matrix object

    Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and
    can be included in Matrix Expressions

    Examples
    ========

    >>> from sympy import MatrixSymbol, Identity
    >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
    >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
    >>> A.shape
    (3, 4)
    >>> 2*A*B + Identity(3)
    I + 2*A*B
    """
    is_commutative = False
    is_symbol = True
    _diff_wrt = True

    def __new__(cls, name, n, m):
        rows, cols = _sympify(n), _sympify(m)
        label = Symbol(name) if isinstance(name, string_types) else name
        return Basic.__new__(cls, label, rows, cols)

    def _hashable_content(self):
        return (self.name, self.shape)

    @property
    def shape(self):
        return self.args[1:3]

    @property
    def name(self):
        return self.args[0].name

    def _eval_subs(self, old, new):
        # Substitution touches only the dimensions; the name is preserved.
        new_rows, new_cols = Tuple(*self.shape)._subs(old, new)
        return MatrixSymbol(self.name, new_rows, new_cols)

    def __call__(self, *args):
        raise TypeError("%s object is not callable" % self.__class__)

    def _entry(self, i, j, **kwargs):
        return MatrixElement(self, i, j)

    @property
    def free_symbols(self):
        return {self}

    def doit(self, **hints):
        if not hints.get('deep', True):
            return self
        rows = self.args[1].doit(**hints)
        cols = self.args[2].doit(**hints)
        return type(self)(self.name, rows, cols)

    def _eval_simplify(self, **kwargs):
        return self

    def _eval_derivative_matrix_lines(self, x):
        if self == x:
            # d(X)/d(X): identity on both the row and the column line.
            return [_LeftRightArgs(
                first=Identity(self.shape[0]),
                second=Identity(self.shape[1]),
                transposed=False,
            )]
        return [_LeftRightArgs(
            ZeroMatrix(x.shape[0], self.shape[0]),
            ZeroMatrix(x.shape[1], self.shape[1]),
            transposed=False,
        )]
class Identity(MatrixExpr):
    """The Matrix Identity I - multiplicative identity

    Examples
    ========

    >>> from sympy.matrices import Identity, MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 5)
    >>> I = Identity(3)
    >>> I*A
    A
    """

    is_Identity = True

    def __new__(cls, n):
        return super(Identity, cls).__new__(cls, _sympify(n))

    @property
    def rows(self):
        return self.args[0]

    @property
    def cols(self):
        return self.args[0]

    @property
    def shape(self):
        size = self.args[0]
        return (size, size)

    @property
    def is_square(self):
        return True

    def _eval_transpose(self):
        return self

    def _eval_trace(self):
        return self.rows

    def _eval_inverse(self):
        return self

    def conjugate(self):
        return self

    def _entry(self, i, j, **kwargs):
        # Decide the entry when i == j is provable either way; otherwise
        # fall back to a symbolic Kronecker delta.
        same = Eq(i, j)
        if same is S.true:
            return S.One
        if same is S.false:
            return S.Zero
        return KroneckerDelta(i, j)

    def _eval_determinant(self):
        return S.One
class GenericIdentity(Identity):
    """
    An identity matrix with no specified shape.

    It exists mainly so that ``MatMul()`` called with no arguments can return
    something meaningful.  Accessing ``rows``, ``cols`` or ``shape`` raises
    ``TypeError``.
    """
    def __new__(cls):
        # Skip Identity.__new__ (it requires a size argument) by starting the
        # MRO lookup after Identity.
        return super(Identity, cls).__new__(cls)

    def _shapeless(self):
        raise TypeError("GenericIdentity does not have a specified shape")

    rows = property(_shapeless)
    cols = property(_shapeless)
    shape = property(_shapeless)
    del _shapeless

    # Avoid Matrix.__eq__ which might call .shape
    def __eq__(self, other):
        return isinstance(other, GenericIdentity)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return super(GenericIdentity, self).__hash__()
class ZeroMatrix(MatrixExpr):
    """The Matrix Zero 0 - additive identity

    Examples
    ========

    >>> from sympy import MatrixSymbol, ZeroMatrix
    >>> A = MatrixSymbol('A', 3, 5)
    >>> Z = ZeroMatrix(3, 5)
    >>> A + Z
    A
    >>> Z*A.T
    0
    """
    is_ZeroMatrix = True

    def __new__(cls, m, n):
        return super(ZeroMatrix, cls).__new__(cls, m, n)

    @property
    def shape(self):
        return (self.args[0], self.args[1])

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        """Raise the zero matrix to the power ``other``.

        ``Z**0`` is the identity and ``Z**k`` for ``k >= 1`` is ``Z``.
        Exponents provably below 1 raise ``ValueError`` because the zero
        matrix is not invertible.  If the sign of the exponent cannot be
        decided (e.g. a plain Symbol), the power is left unevaluated as a
        ``MatPow``.

        Raises
        ======

        ShapeError
            If the matrix is not square and ``other`` is not 1.
        ValueError
            If ``other`` is provably less than 1 (and not 0).
        """
        if other != 1 and not self.is_square:
            raise ShapeError("Power of non-square matrix %s" % self)
        if other == 0:
            return Identity(self.rows)
        below_one = other < 1
        if below_one is S.true:
            raise ValueError("Matrix det == 0; not invertible.")
        if below_one is S.false:
            return self
        # Undecidable comparison (a Relational): previously ``if other < 1``
        # raised "cannot determine truth value of Relational".  The result
        # depends on whether ``other`` is 0, negative or positive, so keep
        # the power unevaluated.
        return MatPow(self, other)

    def _eval_transpose(self):
        return ZeroMatrix(self.cols, self.rows)

    def _eval_trace(self):
        return S.Zero

    def _eval_determinant(self):
        return S.Zero

    def conjugate(self):
        return self

    def _entry(self, i, j, **kwargs):
        return S.Zero

    def __nonzero__(self):
        return False

    __bool__ = __nonzero__
class GenericZeroMatrix(ZeroMatrix):
    """
    A zero matrix with no specified shape.

    It exists mainly so that ``MatAdd()`` called with no arguments can return
    something meaningful.  Accessing ``rows``, ``cols`` or ``shape`` raises
    ``TypeError``.
    """
    def __new__(cls):
        # Skip ZeroMatrix.__new__ (it requires dimensions) by starting the
        # MRO lookup after ZeroMatrix.
        return super(ZeroMatrix, cls).__new__(cls)

    def _shapeless(self):
        raise TypeError("GenericZeroMatrix does not have a specified shape")

    rows = property(_shapeless)
    cols = property(_shapeless)
    shape = property(_shapeless)
    del _shapeless

    # Avoid Matrix.__eq__ which might call .shape
    def __eq__(self, other):
        return isinstance(other, GenericZeroMatrix)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return super(GenericZeroMatrix, self).__hash__()


def matrix_symbols(expr):
    """Return the free symbols of ``expr`` that are matrices."""
    return list(filter(lambda sym: sym.is_Matrix, expr.free_symbols))


class _LeftRightArgs(object):
    r"""
    Helper class to compute matrix derivatives.

    Differentiating an expression with respect to a matrix `X_{mn}` produces
    two chains ("lines") of matrix multiplications: one contracted with the
    `m` index (``first``) and one contracted with the `n` index (``second``).
    Transposition swaps the side on which new matrices are attached to the
    lines; a trace joins the ends of the two lines.
    """

    def __init__(self, first, second, higher=S.One, transposed=False):
        self.first = first
        self.second = second
        self.higher = higher
        self.transposed = transposed

    def __repr__(self):
        def shape_of(obj):
            return obj.shape if isinstance(obj, MatrixExpr) else None

        return "_LeftRightArgs(first=%s[%s], second=%s[%s], higher=%s, transposed=%s)" % (
            self.first, shape_of(self.first),
            self.second, shape_of(self.second),
            self.higher, self.transposed,
        )

    def transpose(self):
        # Mutates in place and returns self for chaining.
        self.transposed = not self.transposed
        return self

    def matrix_form(self):
        has_first = self.first != 1
        if has_first and self.higher != 1:
            raise ValueError("higher dimensional array cannot be represented")
        # Drop 1x1 identity factors (needed by `a.diff(a)` where `a` is a
        # vector).
        one_by_one = Identity(1)
        if self.first == one_by_one:
            return self.second.T
        if self.second == one_by_one:
            return self.first
        return self.first*self.second.T if has_first else self.higher

    def rank(self):
        """
        Number of non-trivial dimensions (unrelated to the linear-algebra
        rank of a matrix).
        """
        total = 0
        for side in (self.first, self.second):
            if side != 1:
                total += sum(dim != 1 for dim in side.shape)
        if self.higher != 1:
            total += 2
        return total

    def append_first(self, other):
        self.first *= other

    def append_second(self, other):
        self.second *= other

    def __hash__(self):
        return hash((self.first, self.second, self.transposed))

    def __eq__(self, other):
        if not isinstance(other, _LeftRightArgs):
            return False
        return (self.first == other.first) and (self.second == other.second) and (self.transposed == other.transposed)


# Imported at the bottom of the module to avoid circular imports.
from .matmul import MatMul
from .matadd import MatAdd
from .matpow import MatPow
from .transpose import Transpose
from .inverse import Inverse