Every expression is a tree. (2 + 3) * x looks like this:
      *          In Python: ('*', ('+', 2, 3), 'x')
     / \
    +   x        First element = operator.
   / \           Rest = children.
  2   3          Tuples all the way down.
A rewrite rule says: when you see this pattern, replace it with that. A handful of them, applied until nothing changes, can simplify or differentiate arbitrary expressions.
Here’s the whole thing. 90 lines of Python. A complete symbolic differentiator:
class _W:
    """Type of the wildcard marker; instantiated once below as ``_``."""

    def __repr__(self):
        # Printed as a bare underscore.
        return '_'


# The wildcard marker itself.
_ = _W()
def rewrite(t, *rules):
    """Apply ``rules`` to ``t`` repeatedly until none of them changes it.

    Rules are tried in the order given. As soon as one returns a value that
    differs from the current tree, that value becomes the current tree and
    the scan restarts from the first rule. The tree is returned once a full
    scan produces no change.

    Args:
        t: The starting tree (any value the rules accept).
        *rules: Callables taking a tree and returning a tree.

    Returns:
        The first tree for which every rule returns an equal value. The loop
        has no step limit, so it only ends if such a tree is reached.
    """
    current = t
    changed = True
    while changed:
        changed = False
        for rule in rules:
            candidate = rule(current)
            if candidate != current:
                # Restart the scan from the first rule on the new tree.
                current, changed = candidate, True
                break
    return current
def bottom_up(rule):
    """Lift ``rule`` so it is applied at every node of a tree, children first.

    For a non-empty tuple the first element is kept as-is, each remaining
    element is processed recursively, and ``rule`` is then applied to the
    rebuilt tuple. Any other value (including the empty tuple) is passed
    straight to ``rule``.

    Args:
        rule: Callable taking a tree and returning a tree.

    Returns:
        A callable with the same tree-in, tree-out shape as ``rule``.
    """
    def go(t):
        if not (isinstance(t, tuple) and t):
            return rule(t)
        head, *children = t
        # The head (operator) is not visited; only the operands are.
        return rule((head, *map(go, children)))

    return go
class when:
    """A single rewrite rule: a pattern plus the replacement for a match.

    Build it as ``when(*pattern).then(replacement)`` and call the instance on
    a tree. The call returns the replacement when the whole tree matches the
    pattern, and the tree itself otherwise.

    Pattern elements are interpreted in this order:

    * a callable that is not a class is a predicate; it matches and captures
      any value for which it returns a truthy result;
    * the wildcard ``_`` matches and captures any value;
    * a string starting with ``'$'`` is a named capture; a repeated name must
      meet an equal value each time;
    * anything else must compare equal to the value, or be a tuple of the
      same length whose elements match pairwise.
    """

    def __init__(self, *pat):
        self.pat = pat
        self.fn = None  # set by then(); the rule is inert until then

    def then(self, f):
        """Set the replacement and return ``self`` so the call can be chained.

        A callable ``f`` receives the captured values as positional
        arguments. Any other ``f`` is used as a constant replacement.
        """
        if callable(f):
            self.fn = f
        else:
            # Accept and ignore whatever was captured; always yield f.
            self.fn = lambda *_ignored, v=f: v
        return self

    def __call__(self, t):
        """Return the replacement for ``t`` if it matches, else ``t`` itself."""
        bindings = self._m(self.pat, t)
        if bindings is None or not self.fn:
            return t
        # Captures are passed in the order they were bound.
        return self.fn(*bindings.values())

    def _m(self, p, t, b=None):
        """Match pattern ``p`` against tree ``t``, recording captures in ``b``.

        Returns the bindings dict on success (possibly empty) and ``None`` on
        failure.
        """
        if b is None:
            b = {}

        # Predicate: a callable that is not a class.
        if callable(p) and not isinstance(p, type):
            if not p(t):
                return None
            b[f'_{len(b)}'] = t
            return b

        # Anonymous wildcard.
        if p is _:
            b[f'_{len(b)}'] = t
            return b

        # Named capture; a second occurrence must agree with the first.
        if isinstance(p, str) and p.startswith('$'):
            if p not in b:
                b[p] = t
                return b
            return b if b[p] == t else None

        # Literal match.
        if p == t:
            return b

        # Structural match: same-length tuples, element by element.
        same_shape = (
            isinstance(p, tuple) and isinstance(t, tuple) and len(p) == len(t)
        )
        if not same_shape:
            return None
        for sub_p, sub_t in zip(p, t):
            if self._m(sub_p, sub_t, b) is None:
                return None
        return b
def is_lit(x):
    """Return True when ``x`` is an ``int`` or a ``float``."""
    return isinstance(x, int) or isinstance(x, float)
def const_wrt(var):
    """Build a predicate telling whether an expression is free of ``var``.

    The returned ``check(e)`` is False when ``e`` equals ``var`` or when any
    operand of a tuple expression (every element after the first) fails the
    check, and True otherwise. The first element of a tuple is treated as the
    operator and is not inspected.
    """
    def check(e):
        if e == var:
            return False
        if not isinstance(e, tuple):
            return True
        for operand in e[1:]:
            if not check(operand):
                return False
        return True

    return check
# Algebraic clean-up rules: identity elements, absorbing elements and
# folding of operations whose operands are both numeric literals.
simplify = [
    when('+', 0, _).then(lambda x: x),  # 0 + x = x
    when('+', _, 0).then(lambda x: x),  # x + 0 = x
    # x - 0 = x. The difference rule in `diff` leaves `u - 0` behind whenever
    # the subtrahend is constant, so it needs its own clean-up rule.
    when('-', _, 0).then(lambda x: x),
    when('*', 0, _).then(0),            # 0 * x = 0
    when('*', _, 0).then(0),            # x * 0 = 0
    when('*', 1, _).then(lambda x: x),  # 1 * x = x
    when('*', _, 1).then(lambda x: x),  # x * 1 = x
    when('/', _, 1).then(lambda x: x),  # x / 1 = x
    when('^', _, 0).then(1),            # x^0 = 1
    when('^', _, 1).then(lambda x: x),  # x^1 = x
    # Constant folding: both operands are numeric literals.
    when('+', is_lit, is_lit).then(lambda a, b: a + b),
    when('-', is_lit, is_lit).then(lambda a, b: a - b),
    when('*', is_lit, is_lit).then(lambda a, b: a * b),
]
# Differentiation rules for nodes of the form ('d', expr, var).
#
# Every rule is generic in the differentiation variable: the base cases use
# the named capture '$v' rather than a hard-coded 'x', so they agree with the
# composite rules, which already accept any '$v'. Each elementary function
# has a direct rule (argument is the variable itself) and a chain rule
# (any other argument).
diff = [
    # constant: d(e)/dv = 0 when e does not mention v. When e does depend on
    # v the node is returned unchanged, so the rule counts as "no match".
    when('d', _, '$v').then(
        lambda e, v: 0 if const_wrt(v)(e) else ('d', e, v)),
    when('d', '$v', '$v').then(1),  # variable
    when('d', ('+', _, _), '$v').then(  # sum
        lambda u, w, v: ('+', ('d', u, v), ('d', w, v))),
    when('d', ('-', _, _), '$v').then(  # difference
        lambda u, w, v: ('-', ('d', u, v), ('d', w, v))),
    when('d', ('*', _, _), '$v').then(  # product
        lambda u, w, v: ('+', ('*', u, ('d', w, v)), ('*', w, ('d', u, v)))),
    when('d', ('^', '$v', is_lit), '$v').then(  # power
        lambda v, n: ('*', n, ('^', v, n - 1))),
    when('d', ('^', _, is_lit), '$v').then(  # chain+power
        lambda u, n, v: ('*', ('*', n, ('^', u, n - 1)), ('d', u, v))),
    when('d', ('sin', '$v'), '$v').then(lambda v: ('cos', v)),  # sin
    when('d', ('sin', _), '$v').then(  # chain+sin
        lambda u, v: ('*', ('cos', u), ('d', u, v))),
    when('d', ('cos', '$v'), '$v').then(  # cos
        lambda v: ('-', 0, ('sin', v))),
    when('d', ('cos', _), '$v').then(  # chain+cos
        lambda u, v: ('*', ('-', 0, ('sin', u)), ('d', u, v))),
    when('d', ('exp', '$v'), '$v').then(lambda v: ('exp', v)),  # exp
    when('d', ('exp', _), '$v').then(  # chain+exp
        lambda u, v: ('*', ('exp', u), ('d', u, v))),
    when('d', ('ln', '$v'), '$v').then(lambda v: ('/', 1, v)),  # ln
    when('d', ('ln', _), '$v').then(  # chain+ln
        lambda u, v: ('*', ('/', 1, u), ('d', u, v))),
    when('d', ('/', _, _), '$v').then(  # quotient
        lambda u, w, v: ('/',
                         ('-', ('*', w, ('d', u, v)), ('*', u, ('d', w, v))),
                         ('^', w, 2))),
]
# Lift every rule so it is tried at each node; derivative rules come before
# the simplification rules.
rules = list(map(bottom_up, diff + simplify))

# d/dx (x^2 * sin x)
result = rewrite(
    ('d', ('*', ('^', 'x', 2), ('sin', 'x')), 'x'),
    *rules,
)
# => ('+', ('*', ('^', 'x', 2), ('cos', 'x')), ('*', ('sin', 'x'), ('*', 2, 'x')))
That’s it. rewrite applies rules until nothing changes. bottom_up walks the tree leaves-first. when does pattern matching: _ matches anything, $x captures a named variable, callables act as predicates. The rest is just calculus rules written as patterns.
The chain rule's recursion is never coded explicitly. When a rule like d/dx sin(u) produces cos(u) * d/dx u, that new d node gets rewritten by the same rules on the next pass. The recursion emerges from the fixed-point loop.
The widgets below run a JS reimplementation of the same logic. Step through each one to watch the engine think.
Bottom-up traversal
The key mechanism. Rules only match at the root of whatever they’re given, so bottom_up recurses into children first, then tries the rule at the rebuilt node. Leaves light up first, then the engine works upward.
Arithmetic simplification
More rules, same engine. Identity elements, absorbing elements, constant folding, cancellation:
Symbolic differentiation
The same engine with the diff rules from above. Watch the product rule split into two branches, then the power and chain rules reduce each piece, then simplification cleans up:
Discussion