Rewriting Aesara graphs
This article is a port of Brandon Willard's Tour of the Symbolic PyMC library, and is a simplified version of the example in Aesara's documentation. The text is almost a verbatim copy of the original, but mistakes are obviously mine.
In this document we will implement a symbolic "search-and-replace" that changes Aesara graphs like at.dot(A, x + y) to at.dot(A, x) + at.dot(A, y). In other words, we will demonstrate how to implement the distributive property of matrix multiplication so that it can be applied to any Aesara graph. Aesara allows one to implement rewrite rules like the distributive property, and many other sophisticated manipulations of graphs, by providing flexible, pure Python versions of core operations in symbolic computation. These operations are then combined and orchestrated through the relational programming DSL miniKanren.
More specifically, we’ll introduce the basic unification and reification operations and explicitly show how they relate to graph manipulation and the modeling of high-level mathematical relations. Along the way, we’ll cover some of the necessary details behind Aesara graphs.
We start by creating a graph of our target expression, i.e. at.dot(A, x + y), in Aesara. We need to do this in order to determine exactly what we're searching for and, later, what to put in its place.
import aesara.tensor as at

A_tt = at.matrix("A")
x_tt = at.vector("x")
y_tt = at.vector("y")
z_tt = at.dot(A_tt, x_tt + y_tt)
We can get a text print-out of the graph using the debug print function dprint:
import aesara

aesara.dprint(z_tt)
dot [id A] ''
 |A [id B]
 |Elemwise{add,no_inplace} [id C] ''
   |x [id D]
   |y [id E]
The output of dprint shows the underlying operators (dot, add) and their arguments.
To "match/search for" combinations of Aesara operations (or, as we just saw, graphs) we use unification; to "replace" parts of the graph (well, produce a copy with replaced parts) we use reification. Aesara provides support for these via expression tuples (etuples).
S-expressions
We can convert an Aesara graph into an S-expression-like form using etuples.
from etuples import etuple, etuplize
from IPython.lib.pretty import pprint

z_et = etuplize(z_tt)
pprint(z_et)
e( e(aesara.tensor.math.Dot), A, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))
An etuple is like a normal tuple, except that its first element is a Callable and the remaining elements are the Callable's arguments. As above, a pretty-printed etuple looks like a tuple prefixed by an e.
By working with etuples we can use arbitrary Python functions in conjunction with Aesara graphs and logic variable arguments. Basically, an etuple can be manipulated until all of its constituent logic variables are replaced with valid arguments to the function/operator. At that point the etuple can be evaluated.
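To make the evaluation step concrete, here is a minimal pure-Python sketch. It is not the etuples implementation: plain tuples stand in for etuples, and the evaluate helper is a hypothetical name introduced only for illustration.

```python
import operator


def evaluate(expr):
    """Recursively evaluate a (callable, *args) tuple; return anything else as-is."""
    if isinstance(expr, tuple) and expr and callable(expr[0]):
        func, *args = expr
        # Evaluate nested expressions first, then apply the callable.
        return func(*(evaluate(arg) for arg in args))
    return expr


# (add, 1, (mul, 2, 3)) stands for 1 + 2 * 3
expr = (operator.add, 1, (operator.mul, 2, 3))
result = evaluate(expr)
print(result)
```

The tuple can be built, inspected and modified freely before evaluate is called, which is the property that makes this representation convenient for pattern matching.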
For instance, we can create an etuple that uses the function at.add with logic variable arguments.
from unification import var

x_lv, y_lv = var('x'), var('y')
add_pattern = etuple(etuplize(at.add), x_lv, y_lv)
It wouldn't normally be possible to call the at.add function with these argument types, as demonstrated in this example:
try:
    at.add(x_lv, 1)
except NotImplementedError as e:
    print(str(e))
Cannot convert ~x to a tensor variable.
We'll get a similar error if we attempt to evaluate the etuple by accessing its ExpressionTuple.evaled_obj property. However, after performing a simple manipulation that replaces the logic variables with valid inputs to at.add (reification), we are able to evaluate the etuple and obtain an Aesara tensor result, as demonstrated by the following code:
from unification import reify

new_add_pattern = reify(add_pattern, {x_lv: 1., y_lv: 1.})
pprint(new_add_pattern)
e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), 1.0, 1.0)
pprint(new_add_pattern.evaled_obj)
Elemwise{add,no_inplace}.0
Working with S-expressions is much like manipulating a subset of the Python AST, so, when using etuples, one is in effect metaprogramming (e.g. by automating the production and evaluation of Aesara graphs using Python code). As a matter of fact, etuples could be recast as ast.Expr and ast.Call objects that, through the use of eval, could achieve the same results, albeit without the more convenient tuple-like structuring.
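As a rough illustration of that comparison, the sketch below builds the call add(1.0, 1.0) with the standard ast module and evaluates it. It uses operator.add instead of at.add so that it runs without Aesara, and it wraps the call in ast.Expression because that is the node type compile accepts in "eval" mode; both choices are assumptions of this sketch.

```python
import ast
import operator

# The AST counterpart of the S-expression (add, 1.0, 1.0)
call_node = ast.Call(
    func=ast.Name(id="add", ctx=ast.Load()),
    args=[ast.Constant(value=1.0), ast.Constant(value=1.0)],
    keywords=[],
)

# compile() needs line/column information on every node
tree = ast.fix_missing_locations(ast.Expression(body=call_node))

# The name "add" is resolved through the namespace passed to eval
result = eval(compile(tree, "<ast>", "eval"), {"add": operator.add})
print(result)
```

Compared with a tuple, the AST version needs explicit node types, contexts and a compile step, which is the inconvenience the text alludes to.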
Operators and their parameters
The etuple form of our matrix-multiplication graph, z_et, printed earlier holds an Aesara operator in the first position of each etuple. For instance, the first element of z_et is the Dot operator:
pprint(z_et[0])
e(aesara.tensor.math.Dot)
Unification and reification
With the ability to use logic variables and Aesara graphs together, we can now "search" or "match" arbitrary graphs using unification and produce new graphs by replacing logic variables using reification.
We start by making "patterns" or templates for the subgraphs we would like to match. Patterns, in this case, take the form of S-expressions with the desired structure and logic variables in place of "unknown" or arbitrary terms that we might like to reference elsewhere.
dot_pattern represents an S-expression that evaluates to a graph in which two terms are matrix-multiplied.
from aesara.tensor.math import Dot

A_lv, B_lv = var("A"), var("B")
dot_pattern = etuple(etuple(Dot), A_lv, B_lv)
"Matching" a graph against this pattern is called unification. Unification of two graphs implies unification of all sub-graphs and elements between them. When unification is successful, it returns a map of logic variables and their unified values. If there are no logic variables in the graphs it simply returns an empty map. If unification fails, it returns False, at least in the implementation we use.
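The three outcomes described above (a map, an empty map, or False) can be reproduced with a small stand-alone sketch. This is not the unification package: the Var class, the walk helper and this unify function are simplified stand-ins that only handle tuples and atoms, and strings replace Aesara operators and tensors.

```python
class Var:
    """A logic variable; two variables are equal only if they are the same object."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def walk(term, subs):
    """Follow bindings in `subs` until reaching a value or an unbound variable."""
    while isinstance(term, Var) and term in subs:
        term = subs[term]
    return term


def unify(u, v, subs=None):
    """Return a substitution dict that makes `u` and `v` equal, or False."""
    subs = {} if subs is None else subs
    u, v = walk(u, subs), walk(v, subs)
    if u is v:
        return subs
    if isinstance(u, Var):
        return {**subs, u: v}
    if isinstance(v, Var):
        return {**subs, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple):
        if len(u) != len(v):
            return False
        for a, b in zip(u, v):
            subs = unify(a, b, subs)
            if subs is False:
                return False
        return subs
    return subs if u == v else False


A_lv, B_lv = Var("A"), Var("B")
pattern = ("dot", A_lv, B_lv)
graph = ("dot", "A", ("add", "x", "y"))

s = unify(pattern, graph)
print(s)                                        # bindings for ~A and ~B
print(unify(("dot", 1), ("dot", 1)))            # no logic variables: empty map
print(unify(("dot", A_lv), ("add", "x")))       # operators differ: False
```

Because an empty map is falsy in Python, a failed unification has to be detected with `is False` rather than a plain truth test.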
Unification
We can perform unification using the function unify. The result is a dict mapping logic variables to their unified values.
from unification import unify

s = unify(dot_pattern, z_et)
pprint(s)
{~A: A, ~B: e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y)}
The logic variable A has been correctly unified with A_tt, while the logic variable B has been correctly unified with the addition of x_tt and y_tt.
Reification
Using reify we can "fill in", or replace, the logic variables of our "pattern" with the matches obtained by unify that are held within the variable s, or we could specify our own substitutions based on that information.
In the following snippet we simply exchange the A_tt tensor for another tensor, X_tt, and create a new graph with that value. The end result is a version of the original graph z_et with the new tensor.
X_tt = at.matrix("X")
s[A_lv] = X_tt

z_et_re = reify(dot_pattern, s)
pprint(z_et_re)
e( e(aesara.tensor.math.Dot), X, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))
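Reification itself is a short recursive substitution. The following sketch mirrors the example above with plain tuples and strings; the Var class and this reify function are simplified stand-ins written for illustration, not the unification package's implementation.

```python
class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def reify(term, subs):
    """Replace every bound logic variable in `term` by its value in `subs`."""
    if isinstance(term, Var):
        # Bound variables are replaced (recursively); unbound ones are kept.
        return reify(subs[term], subs) if term in subs else term
    if isinstance(term, tuple):
        return tuple(reify(t, subs) for t in term)
    return term


A_lv, B_lv = Var("A"), Var("B")
dot_pattern = ("dot", A_lv, B_lv)

# The substitution a unification step would have produced ...
s = {A_lv: "A", B_lv: ("add", "x", "y")}
# ... with one entry swapped out by hand
s[A_lv] = "X"

z_re = reify(dot_pattern, s)
print(z_re)

# Variables without a binding simply stay in place
partial = reify(dot_pattern, {A_lv: "X"})
print(partial)
```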
Finishing our implementation
We can also reify an entirely different graph using the values extracted from the graph z_et. In this case, we create an "output" pattern graph to complement our "input" pattern graph dot_pattern. If we combine our dot product and addition etuple patterns, we can extract all the arguments needed as input to a distributed multiplication pattern.
output_pattern = etuple(
    etuplize(at.add),
    etuple(etuple(Dot), A_lv, x_lv),
    etuple(etuple(Dot), A_lv, y_lv),
)
pprint(output_pattern)
e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e(e(aesara.tensor.math.Dot), ~A, ~x), e(e(aesara.tensor.math.Dot), ~A, ~y))
With logic variables A_lv, x_lv and y_lv mapped to their template-corresponding objects in another graph, we can reify output_pattern and obtain a reified version of said graph.
Using the previous unification results contained in s, we only need to reify output_pattern with those mappings. However, since our pattern refers to the logic variables x_lv and y_lv, we first need to unify these logic variables with the appropriate terms in the graph.
s_add = unify(s[B_lv], add_pattern, s)
pprint(s_add)
{~A: X, ~B: e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y), ~x: x, ~y: y}
z_new = reify(output_pattern, s_add)
aesara.dprint(z_new.evaled_obj)
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |X [id C]
 | |x [id D]
 |dot [id E] ''
   |X [id C]
   |y [id F]
Using only the basics of unification and reification provided by Aesara, one can extract specific elements from Aesara graphs and use them to implement mathematical identities/relations. Through clever use of multiple mathematical relations, one can, for example, construct graph optimizations that turn large classes of user-defined statistical models into computationally tractable reformulations. Similarly, one can construct "normal forms" for models, making it possible to determine whether or not a user-defined model is suitable for a specific sampler.
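The whole unify-then-reify recipe fits in one small function. The sketch below is a self-contained approximation on plain tuples, with strings standing in for operators and tensors; Var, walk, unify, reify and the rewrite helper are hypothetical names written for this illustration and are not Aesara or unification APIs.

```python
class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def walk(term, subs):
    while isinstance(term, Var) and term in subs:
        term = subs[term]
    return term


def unify(u, v, subs):
    """Return an extended substitution dict, or False when matching fails."""
    u, v = walk(u, subs), walk(v, subs)
    if u is v:
        return subs
    if isinstance(u, Var):
        return {**subs, u: v}
    if isinstance(v, Var):
        return {**subs, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple) and len(u) == len(v):
        for a, b in zip(u, v):
            subs = unify(a, b, subs)
            if subs is False:
                return False
        return subs
    return subs if u == v else False


def reify(term, subs):
    term = walk(term, subs)
    if isinstance(term, tuple):
        return tuple(reify(t, subs) for t in term)
    return term


def rewrite(term, lhs, rhs):
    """Return `rhs` filled in from the match of `lhs` on `term`, else `term`."""
    subs = unify(lhs, term, {})
    return term if subs is False else reify(rhs, subs)


A_lv, x_lv, y_lv = Var("A"), Var("x"), Var("y")
lhs = ("dot", A_lv, ("add", x_lv, y_lv))
rhs = ("add", ("dot", A_lv, x_lv), ("dot", A_lv, y_lv))

z = ("dot", "A", ("add", "x", "y"))
z_new = rewrite(z, lhs, rhs)
print(z_new)

# A term that does not match the left-hand side is returned unchanged
unchanged = rewrite(("mul", "A", "x"), lhs, rhs)
print(unchanged)
```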
Next we will introduce another major element of Aesara that orchestrates and simplifies sequences of unifications like the ones we used earlier, provides control-flow-like capabilities, produces fully reified results of arbitrary forms, and does so within a genuinely declarative formalism that carries much of the power of logic programming: miniKanren!
Relational programming in miniKanren
Aesara uses a Python implementation of the embedded domain-specific language miniKanren, provided by the kanren package, to orchestrate more sophisticated uses of unification and reification. For a quick intro, see the basic introduction provided by the kanren package. We'll cover most of the same basic material here.
To start, miniKanren uses goals (in the same sense as logic programming) to assert relations, and the run function evaluates those goals and allows one to specify the exact amount and type of reified output desired from the states that satisfy the goals.
In their most basic form, miniKanren states are simply the substitution maps returned by unification, which–in the normal course of operations–are not dealt with directly.
The basic goals
Normally, a user will only need to construct compound goals from a basic set of primitives. Arguably, the most primitive goal is the equivalence relation under unification, denoted by eq in Python.
In the following code block we ask for all successful results/reifications (signified by the 0 argument) of the logic variable var('q') for the goal eq(var('q'), 1), i.e. unify var('q') with 1.
from kanren import run, eq

q_lv = var('q')
mk_res = run(0, q_lv, eq(q_lv, 1))
pprint(mk_res)
(1,)
Since miniKanren's run always returns a stream of results, we obtain a tuple containing the reified values of q_lv under the one possible state for which our stated goal successfully evaluates.
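One way to picture this machinery is to treat a goal as a function that takes a state (a substitution dict) and yields zero or more new states. The sketch below follows that idea for atoms only; it is not how kanren is implemented, and Var, walk, bind and these versions of eq and run are simplified stand-ins.

```python
from itertools import islice


class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name


def walk(term, state):
    while isinstance(term, Var) and term in state:
        term = state[term]
    return term


def eq(u, v):
    """Goal that succeeds in the states where `u` and `v` can be made equal."""
    def goal(state):
        a, b = walk(u, state), walk(v, state)
        if isinstance(a, Var) and a is not b:
            yield {**state, a: b}
        elif isinstance(b, Var) and a is not b:
            yield {**state, b: a}
        elif a is b or a == b:
            yield state
    return goal


def bind(stream, goal):
    """Apply `goal` to every state of `stream`, flattening the results."""
    for state in stream:
        yield from goal(state)


def run(n, out, *goals):
    """Return up to `n` reified values of `out` (all of them when n == 0)."""
    stream = iter([{}])          # start from a single empty state
    for goal in goals:           # goals are applied in conjunction
        stream = bind(stream, goal)
    results = (walk(out, state) for state in stream)
    return tuple(results) if n == 0 else tuple(islice(results, n))


q_lv = Var("q")
mk_res = run(0, q_lv, eq(q_lv, 1))
print(mk_res)
```

A goal that fails simply yields nothing, so failure shows up as an empty result tuple rather than as an exception.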
The other basic primitives represent conjunction and disjunction of miniKanren goals: lall and lany, respectively.
from kanren import lall

mk_res = run(0, q_lv, lall(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
()
We just used lall to obtain the conjunction of two unification goals. Since we requested that the same logic variable be unified with 1 and 2 simultaneously, which is impossible, we got back an empty stream of results, indicating failure.
Goal disjunction, lany, will split a state stream across goals, producing new distinct states for each:
from kanren import lany

mk_res = run(0, q_lv, lany(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
(1, 2)
The goal disjunction result shows that the logic variable q_lv can be unified with either 1 or 2 under the two unification goals.
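The difference between the two combinators is easy to see when goals are modelled as functions from a state (a dict of bindings) to a stream of states. The sketch below is an approximation written for this article, not kanren's code: eq is restricted to binding a variable to a plain value, lall is evaluated eagerly, and lany concatenates its branches whereas kanren interleaves them.

```python
class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name


def eq(var, value):
    """Simplified goal: bind `var` to `value`, or check an existing binding."""
    def goal(state):
        if var not in state:
            yield {**state, var: value}
        elif state[var] == value:
            yield state
    return goal


def lall(*goals):
    """Conjunction: thread every state through all goals in turn."""
    def goal(state):
        states = [state]
        for g in goals:
            states = [new for s in states for new in g(s)]
        yield from states
    return goal


def lany(*goals):
    """Disjunction: collect the states produced by each goal separately."""
    def goal(state):
        for g in goals:
            yield from g(state)
    return goal


q_lv = Var("q")

both = [s[q_lv] for s in lall(eq(q_lv, 1), eq(q_lv, 2))({})]
either = [s[q_lv] for s in lany(eq(q_lv, 1), eq(q_lv, 2))({})]
print(both)
print(either)
```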
A common pattern of disjunction and conjunction is called conde, and it mirrors the Lisp function cond, which is effectively a type of compound if ... elif ... elif ... statement. Specifically, conde([x_1, ...], ..., [y_1, ...]) is the same as lany(lall(x_1, ...), ..., lall(y_1, ...)), i.e. a disjunction of goal conjunctions.
from kanren import conde

r_lv = var("r")
mk_res = run(
    0,
    [q_lv, r_lv],
    conde(
        [eq(q_lv, 1), eq(r_lv, 10)],
        [eq(q_lv, 2), eq(r_lv, 20)],
    ),
)
pprint(mk_res)
([1, 10], [2, 20])
We introduced another logic variable r_lv and requested the reified values of a list containing both logic variables. The output reflects the idea that if q_lv is "equal" to 1, then r_lv is "equal" to 10, etc. Unlike normal conditionals, the clauses/branches aren't mutually exclusive; each one is realized whenever the goals in that branch can succeed.
The following code demonstrates how conde can behave more like a traditional conditional statement.
mk_res = run(
    0,
    [q_lv, r_lv],
    lall(
        eq(q_lv, 1),
        conde(
            [eq(q_lv, 1), eq(r_lv, 10)],
            [eq(q_lv, 2), eq(r_lv, 20)],
        ),
    ),
)
pprint(mk_res)
([1, 10],)
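The equivalence between conde and a disjunction of conjunctions can be checked with a few lines of plain Python. As before, this is a sketch under simplifying assumptions rather than kanren's implementation: goals are functions from a state dict to a stream of states, and eq only binds a variable to a plain value.

```python
class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name


def eq(var, value):
    def goal(state):
        if var not in state:
            yield {**state, var: value}
        elif state[var] == value:
            yield state
    return goal


def lall(*goals):
    def goal(state):
        states = [state]
        for g in goals:
            states = [new for s in states for new in g(s)]
        yield from states
    return goal


def lany(*goals):
    def goal(state):
        for g in goals:
            yield from g(state)
    return goal


def conde(*clauses):
    """A disjunction over clauses, each clause being a conjunction of goals."""
    return lany(*(lall(*clause) for clause in clauses))


q_lv, r_lv = Var("q"), Var("r")
branches = conde(
    [eq(q_lv, 1), eq(r_lv, 10)],
    [eq(q_lv, 2), eq(r_lv, 20)],
)

# Without further constraints both branches succeed
all_res = [[s[q_lv], s[r_lv]] for s in branches({})]
print(all_res)

# Fixing q first leaves a single viable branch
one_res = [[s[q_lv], s[r_lv]] for s in lall(eq(q_lv, 1), branches)({})]
print(one_res)
```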
A better implementation
Since miniKanren uses unification and reification, we can apply its basic goals to Aesara graphs, as we did earlier, and reproduce the entire implementation in a much more concise manner.
# The distributed form dot(A, x) + dot(A, y)
output_pattern = etuple(
    etuplize(at.add),
    etuple(etuple(Dot), A_lv, x_lv),
    etuple(etuple(Dot), A_lv, y_lv),
)
mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv))
pprint(mk_res)
(e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e(e(aesara.tensor.math.Dot), A, x), e(e(aesara.tensor.math.Dot), A, y)),)
We obtain an etuple that we can evaluate to get the graph:
aesara.dprint(mk_res[0].evaled_obj)
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
We did not need to use the conjunction operation lall explicitly, because all remaining goal arguments to run are automatically applied in conjunction.
Before moving on to the next section and goal construction, let us summarize everything we did in a self-contained example:
import aesara
import aesara.tensor as at
from aesara.tensor.math import Dot
from etuples import etuple, etuplize
from kanren import eq, run
from unification import var
from IPython.lib.pretty import pprint

# Define the graph we want to "modify"
A_tt = at.matrix("A")
x_tt = at.vector("x")
y_tt = at.vector("y")
z_tt = at.dot(A_tt, x_tt + y_tt)
z_et = etuplize(z_tt)

# Input patterns and logic variables
x_lv, y_lv = var('x'), var('y')
add_pattern = etuple(etuplize(at.add), x_lv, y_lv)
A_lv, B_lv = var('A'), var('B')
dot_pattern = etuple(etuple(Dot), A_lv, B_lv)

# Output pattern: dot(A, x) + dot(A, y)
output_pattern = etuple(
    etuplize(at.add),
    etuple(etuple(Dot), A_lv, x_lv),
    etuple(etuple(Dot), A_lv, y_lv),
)

# Using miniKanren
mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv))
aesara.dprint(mk_res[0].evaled_obj)
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
When combinations of miniKanren goals comprise logical units, we can wrap their construction in functions which we call goal constructors.
Goal constructors
Using our distributive law example, we can create a goal constructor that creates our combined pattern and applies it in one go.
def distributeo(in_g, out_g):
    """Create a goal that represents the distribution of matrix multiplication over addition."""
    A_lv, x_lv, y_lv = var(), var(), var()
    # Input form: dot(A, x + y)
    dot_pattern = etuple(etuple(Dot), A_lv, etuple(etuplize(at.add), x_lv, y_lv))
    # Distributed form: dot(A, x) + dot(A, y)
    dist_pattern = etuple(
        etuplize(at.add),
        etuple(etuple(Dot), A_lv, x_lv),
        etuple(etuple(Dot), A_lv, y_lv),
    )
    return lall(eq(in_g, dot_pattern), eq(out_g, dist_pattern))
Our goal constructor represents the relation for the distribution of matrix multiplication over addition. In this sense, it can be run both ways, i.e. it can "expand" a multiplication by distributing it through addition, and it can "contract" it by doing the opposite.
In the following example we "expand" the multiplication:
q_lv = var()
mk_res = run(1, q_lv, distributeo(z_et, q_lv))
z_expanded_et = mk_res[0].evaled_obj
aesara.dprint(z_expanded_et)
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
And in the following example we "contract" the previously expanded result:
q_lv = var()
mk_res = run(1, q_lv, distributeo(q_lv, z_expanded_et))
z_contracted_et = mk_res[0].evaled_obj
aesara.dprint(z_contracted_et)
dot [id A] ''
 |A [id B]
 |Elemwise{add,no_inplace} [id C] ''
   |x [id D]
   |y [id E]
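The reason one relation can run in both directions is that unification is symmetric: whichever side is known fixes the logic variables, and the other side is then read off by reification. The stand-alone sketch below reproduces this on plain tuples. Var, walk, unify, reify, eq, lall and this distributeo are simplified re-implementations for illustration only, with strings standing in for Aesara operators and tensors.

```python
class Var:
    """A logic variable, compared by identity."""

    def __init__(self, name=""):
        self.name = name


def walk(term, subs):
    while isinstance(term, Var) and term in subs:
        term = subs[term]
    return term


def unify(u, v, subs):
    u, v = walk(u, subs), walk(v, subs)
    if u is v:
        return subs
    if isinstance(u, Var):
        return {**subs, u: v}
    if isinstance(v, Var):
        return {**subs, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple) and len(u) == len(v):
        for a, b in zip(u, v):
            subs = unify(a, b, subs)
            if subs is False:
                return False
        return subs
    return subs if u == v else False


def reify(term, subs):
    term = walk(term, subs)
    if isinstance(term, tuple):
        return tuple(reify(t, subs) for t in term)
    return term


def eq(u, v):
    def goal(state):
        new = unify(u, v, state)
        if new is not False:
            yield new
    return goal


def lall(*goals):
    def goal(state):
        states = [state]
        for g in goals:
            states = [new for s in states for new in g(s)]
        yield from states
    return goal


def distributeo(in_g, out_g):
    """Relate dot(A, x + y) to dot(A, x) + dot(A, y)."""
    A, x, y = Var("A"), Var("x"), Var("y")
    dot_pattern = ("dot", A, ("add", x, y))
    dist_pattern = ("add", ("dot", A, x), ("dot", A, y))
    return lall(eq(in_g, dot_pattern), eq(out_g, dist_pattern))


z = ("dot", "A", ("add", "x", "y"))

# "Expand": the input is known, the output is the unknown
q = Var("q")
expanded = reify(q, next(distributeo(z, q)({})))
print(expanded)

# "Contract": the output is known, the input is the unknown
q = Var("q")
contracted = reify(q, next(distributeo(q, expanded)({})))
print(contracted)
```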
Graph-based goals
In most situations the desired graphs will be subgraphs of much larger ones. Aesara introduces some miniKanren goals that apply other goals throughout graphs until a fixed-point is reached. This sequence of operations is generally necessary for graph simplification and rewriting.
In the following example we create a new graph that contains at.dot(A, x + y) as a subgraph.
z_graph_tt = 2 * at.dot(A_tt, x_tt + y_tt) + 1.
z_graph_et = etuplize(z_graph_tt)
pprint(z_graph_et)
e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Mul at 0x7fd9c9f1a560>, <frozendict {}>), e( e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True), TensorConstant{2}), e( e(aesara.tensor.math.Dot), A, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))), e( e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True), TensorConstant{1.0}))
We define graph_walko, a function that walks term graphs and applies our distributeo goal throughout the graph until the applicable subgraph is found (and replaced):
from functools import partial

from etuples.core import ExpressionTuple
from kanren import eq
from kanren.graph import walko

graph_walko = partial(walko, rator_goal=eq)

q_lv = var()
mk_res = run(1, q_lv, graph_walko(distributeo, z_graph_et, q_lv))
aesara.dprint(mk_res[0].evaled_obj)
Elemwise{add,no_inplace} [id A] ''
 |Elemwise{mul,no_inplace} [id B] ''
 | |InplaceDimShuffle{x} [id C] ''
 | | |TensorConstant{2} [id D]
 | |Elemwise{add,no_inplace} [id E] ''
 |   |dot [id F] ''
 |   | |A [id G]
 |   | |x [id H]
 |   |dot [id I] ''
 |     |A [id G]
 |     |y [id J]
 |InplaceDimShuffle{x} [id K] ''
   |TensorConstant{1.0} [id L]
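The idea of applying a rule throughout a graph until nothing changes can be sketched without miniKanren. The code below is a deterministic, one-directional approximation on nested tuples and is not how kanren.graph.walko works; the distribute, rewrite_everywhere and rewrite_fixed_point helpers are hypothetical names used only in this illustration.

```python
def distribute(term):
    """Return the distributed form of ("dot", A, ("add", x, y)), or None."""
    if (
        isinstance(term, tuple)
        and len(term) == 3
        and term[0] == "dot"
        and isinstance(term[2], tuple)
        and len(term[2]) == 3
        and term[2][0] == "add"
    ):
        _, a, (_, x, y) = term
        return ("add", ("dot", a, x), ("dot", a, y))
    return None


def rewrite_everywhere(rule, term):
    """Apply `rule` at the root when it matches, then recurse into the children."""
    new = rule(term)
    if new is not None:
        term = new
    if isinstance(term, tuple):
        return tuple(rewrite_everywhere(rule, t) for t in term)
    return term


def rewrite_fixed_point(rule, term):
    """Repeat full passes over the term until a pass changes nothing."""
    while True:
        new = rewrite_everywhere(rule, term)
        if new == term:
            return term
        term = new


# 2 * dot(A, x + y) + 1.0, written as nested tuples
z_graph = ("add", ("mul", 2, ("dot", "A", ("add", "x", "y"))), 1.0)
z_rewritten = rewrite_fixed_point(distribute, z_graph)
print(z_rewritten)
```

Unlike the relational version, this sketch can only expand; running the rule backwards would require writing a second function by hand.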