from __future__ import print_function, division
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy.core.numbers import Integer, Rational, Float
from sympy.core.compatibility import default_sort_key
from sympy.core.add import Add
from sympy.core.mul import Mul
__all__ = ['dotprint']
# Ordered (class, style-dict) pairs used as the default ``styles`` argument.
# styleof() applies them in sequence, so entries for later classes override
# earlier ones on shared keys.
default_styles = (
    (Basic, {'color': 'blue', 'shape': 'ellipse'}),
    (Expr, {'color': 'black'}),
)

# Classes whose args purestr() sorts with default_sort_key before printing.
sort_classes = (Add, Mul)

# Classes whose printed arguments purestr() reads from ``__slots__`` instead
# of ``.args``.
slotClasses = (Symbol, Integer, Rational, Float)
# XXX: Why not just use srepr()?
def purestr(x):
    """ A string that follows obj = type(obj)(*obj.args) exactly

    Objects that are not instances of ``Basic`` are rendered with ``str``.
    A ``Basic`` object is rendered as ``ClassName(arg, ...)`` with every
    argument rendered recursively by this function.  The arguments are taken
    from:

    - the attributes named in ``x.__slots__`` when ``type(x)`` is one of
      ``slotClasses``;
    - ``x.args`` sorted with ``default_sort_key`` when ``type(x)`` is one of
      ``sort_classes``;
    - ``x.args`` unchanged otherwise.
    """
    if not isinstance(x, Basic):
        return str(x)

    # ``type(x) in ...`` is an exact-type test, so subclasses of the listed
    # classes fall through to the generic ``x.args`` branch.
    if type(x) in slotClasses:
        args = [getattr(x, slot) for slot in x.__slots__]
    elif type(x) in sort_classes:
        args = sorted(x.args, key=default_sort_key)
    else:
        args = x.args

    return "%s(%s)" % (type(x).__name__, ', '.join(map(purestr, args)))
def styleof(expr, styles=default_styles):
    """ Merge style dictionaries in order

    ``styles`` is an iterable of ``(class, style_dict)`` pairs.  Every pair
    whose class ``expr`` is an instance of contributes its entries to the
    result, and pairs that come later override earlier ones on shared keys.
    A new dictionary is returned; the dictionaries in ``styles`` are not
    modified.

    Examples
    ========

    >>> from sympy import Symbol, Basic, Expr
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr,  {'color': 'black'})]

    >>> styleof(Basic(1), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    style = {}
    for typ, sty in styles:
        if isinstance(expr, typ):
            style.update(sty)
    return style
def attrprint(d, delimiter=', '):
    """ Print a dictionary of attributes

    Each item of ``d`` is rendered as ``"key"="value"``.  The items are
    emitted in sorted order and joined with ``delimiter``.  An empty
    dictionary gives an empty string.

    Examples
    ========

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    return delimiter.join('"%s"="%s"' % item for item in sorted(d.items()))
def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ String defining a node

    The node's attributes come from ``styleof(expr, styles)`` plus a
    ``label`` entry: the class name when ``expr`` is a ``Basic`` that is not
    an atom, and ``labelfunc(expr)`` otherwise.  The node identifier is
    ``purestr(expr)``; when ``repeat`` is true, ``_`` followed by the string
    form of ``pos`` is appended to it.

    Examples
    ========

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol(x)_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    style = styleof(expr, styles)

    if isinstance(expr, Basic) and not expr.is_Atom:
        label = str(expr.__class__.__name__)
    else:
        label = labelfunc(expr)
    style['label'] = label

    expr_str = purestr(expr)
    if repeat:
        # ``pos`` is a tuple, so convert it explicitly; passing it straight
        # to ``%`` would be treated as the tuple of format arguments.
        expr_str += '_%s' % str(pos)

    return '"%s" [%s];' % (expr_str, attrprint(style))
def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ List of strings for all expr->expr.arg pairs

    See the docstring of dotprint for explanations of the options.

    An empty list is returned when ``atom(expr)`` is true.  Otherwise one
    edge string is produced per element of ``expr.args``, in order.  When
    ``repeat`` is true, the parent identifier gets the suffix ``_<pos>`` and
    the i-th argument gets ``_<pos + (i,)>``.

    Examples
    ========

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    """
    if atom(expr):
        return []

    # TODO: This is quadratic in complexity (purestr(expr) already
    # contains [purestr(arg) for arg in expr.args]).
    expr_str = purestr(expr)
    arg_strs = [purestr(arg) for arg in expr.args]

    if repeat:
        # ``pos`` is a tuple, so it is converted with str() before being
        # interpolated.
        expr_str += '_%s' % str(pos)
        arg_strs = [
            arg_str + '_%s' % str(pos + (i,))
            for i, arg_str in enumerate(arg_strs)
        ]

    return ['"%s" -> "%s";' % (expr_str, arg_str) for arg_str in arg_strs]
# Skeleton of the DOT document.  dotprint() fills the ``graphstyle``,
# ``nodes`` and ``edges`` slots with ``%`` formatting.
template = (
    "digraph{\n"
    "# Graph style\n"
    "%(graphstyle)s\n"
    "#########\n"
    "# Nodes #\n"
    "#########\n"
    "%(nodes)s\n"
    "#########\n"
    "# Edges #\n"
    "#########\n"
    "%(edges)s\n"
    "}"
)

# Default graph-level attributes.  dotprint() copies this dictionary and
# updates the copy with its extra keyword arguments.
_graphstyle = dict(rankdir='TD', ordering='out')
def dotprint(expr, styles=default_styles, atom=lambda x: not isinstance(x,
    Basic), maxdepth=None, repeat=True, labelfunc=str, **kwargs):
    """
    DOT description of a SymPy expression tree

    Options are

    ``styles``: Styles for different classes.  The default is::

        ((Basic, {'color': 'blue', 'shape': 'ellipse'}),
         (Expr, {'color': 'black'}))

    ``atom``: Function used to determine if an arg is an atom.  The default is
          ``lambda x: not isinstance(x, Basic)``.  Another good choice is
          ``lambda x: not x.args``.

    ``maxdepth``: The maximum depth.  The default is None, meaning no limit
          (any other false value, such as 0, is treated the same way).

    ``repeat``: Whether to use different nodes for separate common
          subexpressions.  The default is True.  For example, for
          ``x + x*y`` with ``repeat=True``, it will have two nodes for ``x``
          and with ``repeat=False``, it will have one (warning: even if it
          appears twice in the same object, like Pow(x, x), it will still
          appear only once.  Hence, with repeat=False, the number of arrows
          out of an object might not equal the number of args it has).

    ``labelfunc``: How to label leaf nodes.  The default is ``str``.  Another
          good option is ``srepr``.  For example with ``str``, the leaf nodes
          of ``x + 1`` are labeled, ``x`` and ``1``.  With ``srepr``, they
          are labeled ``Symbol('x')`` and ``Integer(1)``.

    Additional keyword arguments are included as styles for the graph.

    Examples
    ========

    >>> from sympy.printing.dot import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    <BLANKLINE>
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    <BLANKLINE>
    #########
    # Nodes #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol(x)_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    <BLANKLINE>
    #########
    # Edges #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    }

    """
    # repeat works by adding a signature tuple to the end of each node for its
    # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
    # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].

    # Work on a copy so the module-level defaults are never mutated.
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        # Pre-order walk: emit the node for ``e``, then its outgoing edges,
        # then recurse into every argument that ``atom`` does not reject.
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
        # A node at the depth limit is still emitted, but neither its edges
        # nor its children are.
        if maxdepth and depth >= maxdepth:
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        for i, arg in enumerate(e.args):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template % {'graphstyle': attrprint(graphstyle, delimiter='\n'),
                       'nodes': '\n'.join(nodes),
                       'edges': '\n'.join(edges)}