# Source code for sympy.printing.dot

from __future__ import print_function, division

from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy.core.numbers import Integer, Rational, Float
from sympy.core.compatibility import default_sort_key
from sympy.core.add import Add
from sympy.core.mul import Mul

# Only ``dotprint`` is exported by ``from sympy.printing.dot import *``.
__all__ = ['dotprint']

# Default (class, attribute-dict) pairs used as the ``styles`` argument of
# styleof, dotnode and dotprint.  styleof applies every entry whose class
# matches, in order, so a later entry overrides keys set by an earlier one.
default_styles = ((Basic, {'color': 'blue', 'shape': 'ellipse'}),
          (Expr,  {'color': 'black'}))


# Classes whose args purestr sorts with default_sort_key before printing.
sort_classes = (Add, Mul)
# Classes for which purestr prints the values stored in ``__slots__``
# instead of the entries of ``args``.
slotClasses = (Symbol, Integer, Rational, Float)
# XXX: Why not just use srepr()?
def purestr(x):
    """ Build a constructor-like string ``ClassName(arg, ...)`` for ``x``

    Objects that are not ``Basic`` instances are returned as ``str(x)``.
    For ``Basic`` instances the arguments are rendered recursively:

    - if the exact type is listed in ``slotClasses``, the values of the
      attributes named in ``x.__slots__`` are used;
    - if the exact type is listed in ``sort_classes``, ``x.args`` sorted
      with ``default_sort_key`` are used;
    - otherwise ``x.args`` are used as they are.
    """
    if isinstance(x, Basic):
        cls = type(x)
        # Exact type match on purpose: subclasses are not treated as members
        # of slotClasses / sort_classes.
        if cls in slotClasses:
            parts = [getattr(x, name) for name in x.__slots__]
        elif cls in sort_classes:
            parts = sorted(x.args, key=default_sort_key)
        else:
            parts = x.args
        rendered = ', '.join(purestr(part) for part in parts)
        return "%s(%s)" % (cls.__name__, rendered)
    return str(x)


def styleof(expr, styles=default_styles):
    """ Combine the style dictionaries that apply to ``expr``

    ``styles`` is an iterable of ``(class, dict)`` pairs.  Every pair whose
    class ``expr`` is an instance of contributes its dictionary, in the order
    given, so later pairs override keys set by earlier ones.  A new
    dictionary is returned; the dictionaries in ``styles`` are not modified.

    Examples
    ========

    >>> from sympy import Symbol, Basic, Expr
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr,  {'color': 'black'})]

    >>> styleof(Basic(1), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    merged = {}
    applicable = (attrs for cls, attrs in styles if isinstance(expr, cls))
    for attrs in applicable:
        merged.update(attrs)
    return merged

def attrprint(d, delimiter=', '):
    """ Render a dictionary of attributes as ``"key"="value"`` pairs

    The items are emitted in sorted order and joined with ``delimiter``.

    Examples
    ========

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    rendered = []
    for key, value in sorted(d.items()):
        rendered.append('"%s"="%s"' % (key, value))
    return delimiter.join(rendered)

def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ Return the DOT statement that declares the node for ``expr``

    The node attributes come from ``styleof(expr, styles)`` plus a ``label``:
    the class name for ``Basic`` instances that are not atoms, and
    ``labelfunc(expr)`` for everything else.  The node name is
    ``purestr(expr)``; when ``repeat`` is true, ``'_' + str(pos)`` is
    appended to it.

    Examples
    ========

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol(x)_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    attrs = styleof(expr, styles)

    compound = isinstance(expr, Basic) and not expr.is_Atom
    attrs['label'] = (str(expr.__class__.__name__) if compound
                      else labelfunc(expr))

    node_name = purestr(expr)
    if repeat:
        # The position suffix distinguishes equal subexpressions that occur
        # at different places.
        node_name = node_name + '_%s' % str(pos)
    return '"%s" [%s];' % (node_name, attrprint(attrs))


def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ Return the DOT edge statements from ``expr`` to each of its args

    An empty list is returned when ``atom(expr)`` is true.  Otherwise one
    ``"parent" -> "child";`` string is produced per entry of ``expr.args``,
    in order.  When ``repeat`` is true the parent name gets ``'_' + str(pos)``
    appended and the i-th child gets ``'_' + str(pos + (i,))``.

    See the docstring of dotprint for explanations of the options.

    Examples
    ========

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    """
    if atom(expr):
        return []

    # TODO: This is quadratic in complexity (purestr(expr) already
    # contains [purestr(arg) for arg in expr.args]).
    parent = purestr(expr)
    if repeat:
        parent += '_%s' % str(pos)

    statements = []
    for index, arg in enumerate(expr.args):
        child = purestr(arg)
        if repeat:
            child += '_%s' % str(pos + (index,))
        statements.append('"%s" -> "%s";' % (parent, child))
    return statements

# Skeleton of the DOT output.  dotprint fills the three %(...)s placeholders
# with the graph attributes, the node statements and the edge statements.
template = \
"""digraph{

# Graph style
%(graphstyle)s

#########
# Nodes #
#########

%(nodes)s

#########
# Edges #
#########

%(edges)s
}"""

# Default graph-level attributes.  dotprint copies this dict and overlays any
# extra keyword arguments it receives before printing.
_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}

def dotprint(expr, styles=default_styles,
             atom=lambda x: not isinstance(x, Basic), maxdepth=None,
             repeat=True, labelfunc=str, **kwargs):
    """
    DOT description of a SymPy expression tree

    Options are

    ``styles``: Styles for different classes.  The default is::

        [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
        (Expr, {'color': 'black'})]

    ``atom``: Function used to determine if an arg is an atom.  The default is
          ``lambda x: not isinstance(x, Basic)``.  Another good choice is
          ``lambda x: not x.args``.

    ``maxdepth``: The maximum depth.  The default is None, meaning no limit.

    ``repeat``: Whether to use different nodes for common subexpressions.
          The default is True.  For example, for ``x + x*y`` with
          ``repeat=True``, it will have two nodes for ``x`` and with
          ``repeat=False``, it will have one (warning: even if it appears
          twice in the same object, like Pow(x, x), it will still only
          appear once.  Hence, with repeat=False, the number of arrows out
          of an object might not equal the number of args it has).

    ``labelfunc``: How to label leaf nodes.  The default is ``str``.  Another
          good option is ``srepr``. For example with ``str``, the leaf nodes
          of ``x + 1`` are labeled, ``x`` and ``1``.  With ``srepr``, they
          are labeled ``Symbol('x')`` and ``Integer(1)``.

    Additional keyword arguments are included as styles for the graph.

    Examples
    ========

    >>> from sympy.printing.dot import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    <BLANKLINE>
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    <BLANKLINE>
    #########
    # Nodes #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol(x)_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    <BLANKLINE>
    #########
    # Edges #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    }

    """
    # repeat works by adding a signature tuple to the end of each node for its
    # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
    # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        # Emit the node for e, then its outgoing edges, then recurse into
        # every arg that is not an atom.
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos,
                             repeat=repeat))
        # The truthiness test means a falsy maxdepth (None or 0) imposes
        # no depth limit.
        if maxdepth and depth >= maxdepth:
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        for i, arg in enumerate(e.args):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template % {'graphstyle': attrprint(graphstyle, delimiter='\n'),
                       'nodes': '\n'.join(nodes),
                       'edges': '\n'.join(edges)}