I recently came across microgpt - a single-file, python-only training script for a GPT model.
What stood out to me were the 40 or so lines that implement autograd: automatically calculating gradients of computations done in code. Autograd is not novel (the python autograd library has been around for years). It underpins every deep learning library: jax, torch, tensorflow, and, through the autograd package, even plain numpy.
No, the reason I was so taken aback was coming face to face with the simplicity of the mechanism that has fueled machine learning for decades. It's one thing to call loss.backward(); it's an entirely different thing to pause and examine what's happening underneath.
The goal of this post is to build up those 40 lines, step by step. First, we revisit two rules of differentiation. Then we come up with a representation of differentiable values in code. Finally, we apply the two rules over a few revisions, each more complex.
The Chain Rule
The derivative multiplies over function composition.
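Written out (a standard textbook example, not from the post):

$$\frac{d}{dx} f(g(x)) = \frac{df}{dg} \cdot \frac{dg}{dx}, \qquad \text{e.g.} \quad \frac{d}{dx} \sin(x^2) = \cos(x^2) \cdot 2x$$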
```mermaid
graph LR
x[x]
g["g(x)"]
f["f(x)"]
two["2x"]
x --"dg/dx"--> g
g --"df/dg"--> f
x --"d(2x)/dx"--> two
two --"df/d(2x)"--> f
```
Looking at the graph, $x$ affects $f$ twice: once through $g(x)$ and once through $2x$. The total change in $f$ with respect to $x$ adds up the change contributed along each of the two branches.
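To make this concrete, here is a small numeric check. The functions are my own choice, not from the original: I pick $g(x) = x^2$ and let $f$ multiply its two branches, then compare multiply-along-paths-add-across-paths against a finite difference.

```python
def g(x):
    return x ** 2

def f(x):
    # f depends on x through two branches: g(x) and 2x
    return g(x) * (2 * x)  # f = 2x^3

x = 3.0
# path through g(x):  df/dg * dg/dx = (2x) * (2x)
# path through 2x:    df/d(2x) * d(2x)/dx = (x**2) * 2
two_path = (2 * x) * (2 * x) + (x ** 2) * 2  # = 6x^2 = 54

h = 1e-6
finite_diff = (f(x + h) - f(x - h)) / (2 * h)
print(two_path, round(finite_diff, 3))  # both ≈ 54.0
```

Both paths together give $6x^2$, which agrees with differentiating $f(x) = 2x^3$ directly.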
Representing a differentiable value
We need a representation of each scalar that holds both its value and its gradient. Each operation produces a new Value.
```python
w = Value(2)
x = Value(1)
y = w + x
z = y + x
print(z.data)
```
```mermaid
graph LR
w --"dy/dw"--> y
y --"dz/dy"--> z
x --"dy/dx"--> y
x --"dz/dx"--> z
```
Backward pass on the computation graph
Now, we need a backward pass. What is dz/d[w,x,y]? Note that each intermediate result, starting from the inputs, is a Value object. Also note that the relationships lend themselves to a graph representation. We can call each Value, be it input, intermediate, or output, a node in the computation graph.
Note that the gradients from the two edges into x get added (the distributive property). One of the paths has multiple hops, x--y and y--z; their gradients get multiplied (the chain rule) before being accumulated into x.
Since there are multiple paths back from z to x, x needs to keep track of all the gradients it has accumulated so far. This can be done in the grad attribute.
To traverse edges, we need to track the dependencies of each node. Then, for each operation (i.e. edge), we can calculate d node / d dependency. For addition, the dependencies are the two arguments to the add operation. They can be tracked as a tuple attribute deps. The pseudocode for an add operation is:
```python
y = w + x
z = y + x

# dz/dx = [dz/dy * dy/dx] + dz/dx
#          \--chain rule--/
#         \----distributive property----/

# Starting from the final node
for dep in z.deps:  # i.e. y, x
    # new gradient = current gradient (distributive property; it's 0 for now)
    #              + parent's gradient (dz/dz) * gradient of the add operation (chain rule)
    dep.grad = dep.grad + (z.grad * 1)

# Then proceeding down the graph to the deps of deps
for dep in y.deps:  # i.e. w, x
    # the current gradient is non-zero for x; it was visited earlier ^^
    dep.grad = dep.grad + (y.grad * 1)
```
Putting it together: we define a backward() method that calculates the derivative with respect to each node in the computation graph. Giving grad a default value of None lets us make two observations:

- When calling backward() on a node, the most trivial gradient is that of the node with respect to itself: d node / d node = 1. The first thing we check in backward() is whether grad is None. If so, that node is the root node from which backward() was first called.
- When looking at a node's dependencies, we want to respect gradients that may already have accumulated from other edges in the computation graph. We check whether the grad attribute of a dependency is None. If so, that dependency hasn't started accumulating gradients yet, and we set it to 0. (Why not set it to 1? We only set 1 for the root node. This node is necessarily a dependency of some other node, therefore it is not the root.)
```python
class Value:
    def __init__(self, data, grad=None, deps=()):
        self.data = data
        self.grad = grad
        self.deps = deps

    def __repr__(self):
        return f"V({self.data}, {self.grad})"

    def __add__(self, other):
        data = self.data + other.data
        deps = (self, other)
        return Value(data=data, grad=None, deps=deps)

    def backward(self):
        # The root node: d root / d root = 1
        if self.grad is None:
            self.grad = 1
        for dep in self.deps:
            # This is the first time the dependency is seen.
            # It does not have any gradients accumulated.
            if dep.grad is None:
                dep.grad = 0
            # accumulating edges = last total + gradient of current edge
            dep.grad = dep.grad + (self.grad * 1)
            dep.backward()
```
Let’s try it out:
```python
w = Value(2)
x = Value(1)
y = w + x
z = y + x
# z = (w+x) + x = w + 2x
# dz/dx = 2, dz/dy = 1, dz/dw = 1
print("z", z)

z.backward()
print(f"dz/dz: {z.grad}")
print(f"dz/dy: {y.grad}")
print(f"dz/dx: {x.grad}")
print(f"dz/dw: {w.grad}")
```
Adding operations
Looking closely at the gradient accumulation line, let’s think from the perspective of y:
```python
dep.grad = dep.grad + (self.grad * 1)
```
The self.grad * 1 is the chain rule in action, where 1 is the partial derivative of the addition operation: for y = x + w, dy/dx = 1. We multiply that by whatever gradient is flowing down from above, in this case dz/dy. Therefore, the gradient reaching x (the dep in self.deps) from z through y (self) is dz/dy * dy/dx, i.e. self.grad * 1.
```mermaid
graph LR
x --"add"--> y
y --"add"--> z
```
If the operation from x to y were instead multiplication, the gradient would be different. Therefore, we need to track the operation that relates a node to its dependencies, so that the gradient update multiplies by that operation's own partial derivative:
```python
class Value:
    def __init__(self, data, grad=None, deps=(), partials=()):
        self.data = data
        self.grad = grad
        self.deps = deps
        self.partials = partials

    def __repr__(self):
        return f"V({self.data}, {self.grad})"

    def __add__(self, other):
        data = self.data + other.data
        deps = (self, other)
        partials = (
            lambda self, dep, other_dep: 1,
            lambda self, dep, other_dep: 1,
        )
        return Value(data=data, grad=None, deps=deps, partials=partials)

    def __mul__(self, other):
        data = self.data * other.data
        deps = (self, other)
        partials = (
            lambda self, dep, other_dep: other_dep.data,
            lambda self, dep, other_dep: other_dep.data,
        )
        return Value(data=data, grad=None, deps=deps, partials=partials)

    def backward(self):
        # The root node: d root / d root = 1
        if self.grad is None:
            self.grad = 1
        if not self.deps:
            return  # a leaf node (an input) was reached
        dep, *other = self.deps
        partial, *other_partial = self.partials
        if other:
            other = other[0]
            other_partial = other_partial[0]
        # This is the first time the dependency is seen.
        # It does not have any gradients accumulated.
        if dep.grad is None:
            dep.grad = 0
        # accumulating edges = last total + gradient of current edge
        dep.grad = dep.grad + (self.grad * partial(self, dep, other))
        dep.backward()
        if other:
            # This is the first time the dependency is seen.
            # It does not have any gradients accumulated.
            if other.grad is None:
                other.grad = 0
            # accumulating edges = last total + gradient of current edge
            other.grad = other.grad + (self.grad * other_partial(self, other, dep))
            other.backward()
```
Let’s try it out:
```python
w = Value(2)
x = Value(1)
y = w * x
z = y + x
# z = (w*x) + x = wx + x
# dz/dx = w+1, dz/dy = 1, dz/dw = x
print("z", z.data)

z.backward()
print(f"dz/dz: {z.grad}")
print(f"dz/dy: {y.grad}")
print(f"dz/dx: {x.grad}")
print(f"dz/dw: {w.grad}")
```
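As a quick cross-check (my addition, not from the original), the same gradients can be verified by treating z = w*x + x as a plain function of floats and taking finite differences:

```python
def z_fn(w, x):
    return w * x + x

w, x, h = 2.0, 1.0, 1e-6
# dz/dx should be w + 1 = 3; dz/dw should be x = 1
dz_dx = (z_fn(w, x + h) - z_fn(w, x - h)) / (2 * h)
dz_dw = (z_fn(w + h, x) - z_fn(w - h, x)) / (2 * h)
print(round(dz_dx, 3), round(dz_dw, 3))  # 3.0 1.0
```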
We can clean this up a bit by creating a class-level attribute to store the partial derivative functions for supported operations. There are three kinds of operations:

- Symmetric binary: the partial function is the same for both dependencies. For example z = x+y; dz/dx = dz/dy = 1. Same for multiplication: z = x*y; dz/dx = y, dz/dy = x, i.e. dz/d(one dependency) = the other dependency.
- Unary: the operation has only one dependency. For example $e^x$ or $\sin(x)$.
- Asymmetric binary: the partial function is different for each dependency. For example $z = x^y$.
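For the asymmetric case, the two partials of $z = x^y$ (a standard result, written out here; it is what the '_^' and '^_' entries below encode) are:

$$\frac{\partial z}{\partial x} = y\,x^{y-1}, \qquad \frac{\partial z}{\partial y} = x^y \ln x$$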
```python
import math

class Value:
    _partials = {
        '+': lambda self, dep, other: 1,
        '*': lambda self, dep, other: other.data,
        '_^': lambda self, dep, other: other.data * dep.data ** (other.data - 1),
        '^_': lambda self, dep, other: self.data * math.log(other.data),
        'e': lambda self, dep, other: self.data,
    }

    def __init__(self, data, grad=None, deps=(), partials=()):
        self.data = data
        self.grad = grad
        self.deps = deps
        self.partials = partials

    def __repr__(self):
        return f"V({self.data}, {self.grad})"

    def __add__(self, other):
        data = self.data + other.data
        deps = (self, other)
        partials = (self._partials['+'], self._partials['+'])
        return Value(data=data, grad=None, deps=deps, partials=partials)

    def __mul__(self, other):
        data = self.data * other.data
        deps = (self, other)
        partials = (self._partials['*'], self._partials['*'])
        return Value(data=data, grad=None, deps=deps, partials=partials)

    def __pow__(self, other):
        data = self.data ** other.data
        deps = (self, other)
        partials = (self._partials['_^'], self._partials['^_'])
        return Value(data=data, grad=None, deps=deps, partials=partials)

    def backward(self):
        # The root node: d root / d root = 1
        if self.grad is None:
            self.grad = 1
        if not self.deps:
            return  # a leaf node (an input) was reached
        dep, *other = self.deps
        partial, *other_partial = self.partials
        if other:
            other = other[0]
            other_partial = other_partial[0]
        # This is the first time the dependency is seen.
        # It does not have any gradients accumulated.
        if dep.grad is None:
            dep.grad = 0
        # accumulating edges = last total + gradient of current edge
        dep.grad = dep.grad + (self.grad * partial(self, dep, other))
        dep.backward()
        if other:  # `is not None` would break for unary ops, where other is []
            if other.grad is None:
                other.grad = 0
            other.grad = other.grad + (self.grad * other_partial(self, other, dep))
            other.backward()
```
```python
w = Value(2)
x = Value(2)
y = w * x
z = y + x ** Value(2)
# z = (w*x) + x**2 = wx + x^2
# dz/dx = w + 2x, dz/dy = 1, dz/dw = x
print("z", z.data)

z.backward()
print(f"dz/dz: {z.grad}")
print(f"dz/dy: {y.grad}")
print(f"dz/dx: {x.grad}")
print(f"dz/dw: {w.grad}")
```
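One thing the three categories promise but the class never exercises is a unary operation: the 'e' entry sits unused in _partials. Below is a standalone, trimmed-down sketch (add and exp only; the exp() method name and the single-element deps/partials tuples are my assumptions, not from the original) showing that the same backward() handles a single-dependency node unchanged:

```python
import math

# Trimmed-down Value (add and exp only) to demonstrate a unary operation.
# The exp() method is my addition; backward() mirrors the class above.
class Value:
    _partials = {
        '+': lambda self, dep, other: 1,
        'e': lambda self, dep, other: self.data,  # d(e^x)/dx = e^x, the node's own data
    }

    def __init__(self, data, grad=None, deps=(), partials=()):
        self.data = data
        self.grad = grad
        self.deps = deps
        self.partials = partials

    def __add__(self, other):
        return Value(self.data + other.data,
                     deps=(self, other),
                     partials=(self._partials['+'], self._partials['+']))

    def exp(self):
        # Unary: one dependency, one partial function.
        return Value(math.exp(self.data),
                     deps=(self,),
                     partials=(self._partials['e'],))

    def backward(self):
        if self.grad is None:
            self.grad = 1
        if not self.deps:
            return  # a leaf node (an input) was reached
        dep, *other = self.deps
        partial, *other_partial = self.partials
        if other:  # binary operation: unpack the second dependency
            other = other[0]
            other_partial = other_partial[0]
        if dep.grad is None:
            dep.grad = 0
        dep.grad = dep.grad + self.grad * partial(self, dep, other)
        dep.backward()
        if other:  # an empty list for unary operations, so this is skipped
            if other.grad is None:
                other.grad = 0
            other.grad = other.grad + self.grad * other_partial(self, other, dep)
            other.backward()

x = Value(1.0)
z = x.exp() + x  # z = e^x + x, so dz/dx = e^x + 1
z.backward()
print(x.grad)  # ≈ 3.718, i.e. e + 1
```

With a single dependency, the star-unpacking leaves other as an empty list, so both `if other:` checks fall through and only the one partial is applied.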
Homework: how to take higher order derivatives? Grads of grads? Hint: the grad attribute is calculated by the same operations as the original computation graph: add/multiply/exp. Then, it can be wrapped in a Value too, right? If so, we can call backward() on it as well.