从零搭建深度学习框架(二)用Python实现计算图和自动微分

我们在上一篇文章《从零搭建深度学习框架(一)用NumPy实现GAN》中用Python+NumPy实现了一个简单的GAN模型,并大致设想了一下深度学习框架需要实现的主要功能。其中,不确定性最大的要属于计算图的实现。所以在这篇文章中,我们用Python实现一个简单的计算图,并用它对一个线性模型进行自动微分,作为后续C++开发的思路验证。


计算图的设计

我们先在纸上设计一下如何用计算图执行一个简单的y=w*x+b形式的前向和反向计算,然后以此确定计算图的基本规则。

线性模型的构造

为了便于实现,我们假定需要解决一个线性拟合问题:

样本采样自黑色的直线y=-x+3,我们的线性模型初始化为y=x+1(灰色直线)。以L2距离作为误差函数,模型训练的过程如下:

如果用一个有向无循环图(DAG)把这个过程表示出来,就是:

上方黑色线表示前向计算的过程,下方灰色线表示反向计算的梯度传播过程,被传播的梯度值标记在了箭头旁边。

计算图的构造规则

这样我们就可以暂时确定下来计算图的构造规则了:

  1. 计算图是有向无环的。因为每个节点都需要知道它的输入和输出节点来执行梯度的反向传播。
  2. 计算图包含两种节点,数值节点和计算节点。其中数值节点负责保存数据和梯度,计算节点负责执行数学计算和自动微分。
  3. 数值节点可以没有输入或输出(图的根节点和叶子节点),也可以有多个输入和输出,但其输入和输出必须为计算节点。因为数值之间的运算关系需要用计算节点来描述,数据从一个数值节点直接传给另一个数值节点是没有意义的。
  4. 计算节点必须有输入和输出,其中有且只有一个输出,而且其输入和输出必须为数值节点。这是为了避免在微分计算时多个输出值使求导过程复杂化。如果某一个函数的输出是后续多个不同函数的输入,我们完全可以让这个计算节点的唯一输出数值节点有多条连接到那些函数的出边。

计算图的实现

下面我们开始用Python实现一个简单的计算图。

基类节点

我们首先定义一个基类节点,便于后面继承出数值节点和计算节点。

import random
import warnings
from collections import deque
from enum import Enum


class NodeType(Enum):
    VAL = 0
    COMP = 1


class Node(object):
    def __init__(self, name: str, inputs=None, outputs=None):
        self.name = name
        self.node_type = None
        self.inputs = [] if inputs is None else inputs
        self.outputs = [] if outputs is None else outputs

    def __repr__(self):
        return "{} -> {}".format(self.inputs, self.outputs)

    def __str__(self):
        return self.__repr__()

    def forward(self, **kwargs):
        raise NotImplementedError("`forward` method is not implemented.")

    def backward(self, **kwargs):
        if 'delta' not in kwargs:
            raise ValueError("Arg `delta` must be passed to `Node` in backward pass.")
        delta = kwargs['delta']
        return delta

Nodename属性用来唯一地标识每个节点,在计算图中也会通过name来定位到唯一的节点。forward方法只给出接口,不提供默认实现,有点类似于C++里的纯虚函数。

数值节点

数值节点负责保存数据和梯度,由Node继承而来。

class ValNode(Node):
    def __init__(self, val, name: str, train=True):
        super().__init__(name=name)
        self.val = val
        self.grad = None
        self.train = train
        self.node_type=NodeType.VAL

    def __repr__(self):
        return "{}".format(self.val)

    def forward(self):
        return self.val

    def backward(self, **kwargs):
        delta = self.grad
        if delta is None:
            raise ValueError("Grad has not been calculated for Node: {}".format(self.name))
        return delta

因为梯度的计算由计算节点负责,所以我们索性也让计算节点顺便把计算得到的梯度也赋值给相应的输入数值节点(即ValNode.grad)。于是数值节点本身不需要关心太多操作,只需要在前向计算时给出val,反向计算时直接给出grad

计算节点

与数值节点不同,计算节点需要支持很多种不同的计算操作。我们先定义一个基类计算节点。

class CompNode(Node):
    def __init__(self, name: str):
        super().__init__(name=name)
        self.node_type=NodeType.COMP

然后,我们就可以定义几个我们所需要的计算操作了(加法,乘法和L2距离)。

class AddNode(CompNode):
    def __init__(self, name: str):
        super().__init__(name=name)

    def forward(self, *argv):
        rst = 0
        for node in argv:
            rst += node.forward()
        return rst

    def backward(self, **kwargs):
        if 'delta' not in kwargs:
            raise ValueError("Arg `delta` must be passed to `AddNode` in backward pass.")
        delta = kwargs['delta']
        in_num = len(self.inputs)
        grads = [delta for _ in range(in_num)]
        return grads


class MulNode(CompNode):
    def __init__(self, name: str):
        super().__init__(name=name)

    def forward(self, *argv):
        rst = 1
        for node in argv:
            rst *= node.forward()
        return rst

    def backward(self, **kwargs):
        if 'delta' not in kwargs:
            raise ValueError("Arg `delta` must be passed to `MulNode` in backward pass.")
        delta = kwargs['delta']
        if 'inputs' not in kwargs:
            raise ValueError("Arg `inputs` must be passed to `MulNode` in backward pass.")
        inputs = kwargs['inputs']
        in_num = len(inputs)
        grads = []
        for i in range(in_num):
            grad = delta
            for j in range(in_num):
                if j != i:
                    grad *= inputs[j].val
            grads.append(grad)
        return grads


class L2Node(CompNode):
    def __init__(self, name: str):
        super().__init__(name=name)

    def forward(self, predict, label):
        rst = 0.5 * (predict.forward() - label.forward()) ** 2
        return rst

    def backward(self, **kwargs):
        if 'delta' not in kwargs:
            raise ValueError("Arg `delta` must be passed to `L2Node` in backward pass.")
        delta = kwargs['delta']
        if 'inputs' not in kwargs:
            raise ValueError("Arg `inputs` must be passed to `MulNode` in backward pass.")
        inputs = kwargs['inputs']
        if len(inputs) != 2:
            raise ValueError("A L2Node must have 2 inputs.")
        predict = inputs[0]
        label = inputs[1]
        grad = delta * (predict.val - label.val)
        grads = [grad, -grad]
        return grads

计算节点的forward方法需要调用其输入的数值节点的forward来获取输入数据,在计算梯度时,必须要知道传入它的梯度(delta),可能也需要它原本的输入数值节点的值,然后把计算得到的梯度与delta乘起来,并返回对应到每个输入数值节点的最终梯度值。

由于Python不像C++那样可以很方便地操作指针,我们这里只是在Node中记录了输入和输出节点的name,但并不能直接获取这些节点的值。所以我们在backward方法中只返回梯度值,后面在节点外部的函数再把这些梯度值写入输入节点的grad

计算图的定义

定义好节点之后,我们就可以实现计算图了。

class Graph(object):
    def __init__(self, nodes=None, edges=None):
        self.nodes = {}
        self._node_idx = 0
        self.root_nodes = []
        self.leaf_nodes = []
        self.init_graph(nodes, edges)
        self.valid_graph()

    def add_node(self, node: Node):
        if node.name is None or len(node.name) == 0:
            node.name = "node_{}".format(self._node_idx)
        if node.name in self.nodes:
            raise ValueError("Duplicate node name: {}".format(node.name))
        self.nodes[node.name] = node
        self._node_idx += 1

    def add_edge(self, in_node: str, out_node: str, insert_new_node=False):
        if in_node not in self.nodes:
            if insert_new_node:
                self.add_node(Node(name=in_node))
            else:
                raise ValueError("Input node does not exist for edge: {} -> {}".format(in_node, out_node))
        if out_node not in self.nodes:
            if insert_new_node:
                self.add_node(Node(name=out_node))
            else:
                raise ValueError("Output node does not exist for edge: {} -> {}".format(in_node, out_node))
        self.nodes[in_node].outputs.append(out_node)
        self.nodes[out_node].inputs.append(in_node)

计算图的所有节点保存在Graph.nodes当中,这是一个以namekey,以Nodevaluedictadd_node用来向计算图添加新的节点,如果没有指定Node.name的话,会默认用“node_0, node_1, …”的方式来命名。add_edge用来向计算图添加新的边,如果这条边的输入或输出节点不存在的话,允许新建一个Node加到计算图中会报错提示。由于我们在后面只会用add_node来构建计算图,所以并没有在add_edge中为新建的Node提供数值/计算的选项。

图的初始化

图初始化的思路是,可以接受一个Node构成的list来初始化所有的节点,但节点之间的连通信息需要提前写在Node里;也可以接受一个由name的二元组构成的list来初始化所有的边,在建立边的过程中初始化所需的所有Node

    def init_graph(self, nodes: list, edges: list):
        self.nodes = {}
        self.root_nodes = []
        self.leaf_nodes = []
        if nodes is not None and len(nodes) > 0:
            for node in nodes:
               self.add_node(node_name=None, node=node)
        if edges is not None and len(edges) > 0:
            for edge in edges:
                nin, nout = edge
                self.add_edge(nin, nout, insert_new_node=True)

图的验证

计算图初始化好以后,还要检查一下图中是否存在环,能否从根节点成功到达叶子节点等。

环的检测。检测环用拓扑排序的方法,通过维护一个栈,栈中保存当前图中入度为0的所有节点,然后不断地删除栈顶元素,把该元素的所有输出节点的入度减1,同时把入度减到0的节点压入栈,直到栈空为止。截止时如果图中还有剩余节点,则存在环。

    @property
    def graph(self):
        _graph = {}
        for name, node in self.nodes.items():
            _graph[name] = [set(node.inputs), set(node.outputs)]
        return _graph

    def _find_root_leaf_nodes(self):
        """
        Only look for nodes without inputs or outputs.
        Does not check the validity of graph
        """
        root_nodes = []
        leaf_nodes = []
        for name, node in self.nodes.items():
            # nodes with neither inputs nor outputs are included in both collections
            if not node.inputs:
                root_nodes.append(name)
            if not node.outputs:
                leaf_nodes.append(name)
        return root_nodes, leaf_nodes

    def _check_loop(self, root_nodes=None):
        if not root_nodes:
            root_nodes, _ = self._find_root_leaf_nodes()
        stack = deque()
        for node in root_nodes:
            stack.append(node)
        graph = self.graph
        while len(stack) > 0:
            cur = stack.pop()
            outputs = graph[cur][1]
            graph.pop(cur)
            for out in outputs:
                graph[out][0].remove(cur)
                if len(graph[out][0]) == 0:
                    stack.append(out)
        return True if graph else False

接下来是完整的valid_graph的实现。在检查完环之后,我们从所有入度为0的节点(根节点)出发,遍历到所有出度为0的节点(叶子节点),检查是否存在无法到达的叶子节点;然后从这些叶子节点出发,反向遍历到所有入度为0的节点,检查是否存在无法到达的根节点,以此来检查图中各个节点的连通信息是否完整。可以注意到,在这个过程中,我们并没有对连通分量做出限制(即是否要求所有的节点两两之间都存在连通路径)。只要给出每个节点的计算和连通关系,计算图的每个节点都可以正确地执行前向和反向计算。

    def _get_leaves(self, input_nodes: list, direction='forward'):
        stack = deque(input_nodes)
        graph = self.graph
        rst = []
        if direction == 'forward':
            dir_idx = 1
        elif direction == 'backward':
            dir_idx = 0
        else:
            raise ValueError("Unrecognized arg `direction`: {}".format(direction))
        while len(stack) > 0:
            cur = stack.pop()
            outputs = graph[cur][dir_idx]
            graph.pop(cur)
            if not outputs:
                rst.append(cur)
            else:
                for out in outputs:
                    stack.append(out)
        return rst

    def valid_graph(self):
        if len(self.nodes) == 0:
            return
        root_nodes, leaf_nodes = self._find_root_leaf_nodes()
        if not root_nodes or not leaf_nodes:
            raise ValueError("A graph must have at least 1 root node and 1 leaf node.")
        if self._check_loop(root_nodes):
            raise ValueError("Loop detected in graph.")
        # forward pass
        forward_leaves = set()
        for node in root_nodes:
            leaves = self._get_leaves([node], direction='forward')
            forward_leaves = forward_leaves.union(set(leaves))
        if len(forward_leaves.difference(set(leaf_nodes))):
            # This should not happen, in fact. Just check for sure.
            raise ValueError("What the hell? `forward_leaves - leaf_nodes` should be empty.")
        if len(set(leaf_nodes).difference(forward_leaves)):
            warnings.warn("Found stranded leaf nodes. Ignore them.")
        # backward pass
        backward_leaves = set()
        for node in forward_leaves:
            leaves = self._get_leaves([node], direction='backward')
            backward_leaves = backward_leaves.union(set(leaves))
        if len(backward_leaves.difference(set(root_nodes))):
            raise ValueError("What the hell? `backward_leaves - root_nodes` should be empty.")
        if len(set(root_nodes).difference(backward_leaves)):
            warnings.warn("Found stranded root nodes. Ignore them.")
        self.root_nodes = list(backward_leaves)
        self.leaf_nodes = list(forward_leaves)

图的前向计算

这里我们设计一个类似于静态图的方式来执行整个计算图的前向计算。称这个过程为“类似于静态图”的原因是,我们并没有针对节点的执行效率进行优化,节点本身还是按照动态图的方式来执行,即在计算过程的任何一步都可以打印出节点的值。只是接口的调用方式类似于TensorFlow 1.x那样,必须要传入一个dict提供所有根节点的值才能执行图的前向计算。

    def forward(self, keywords: dict):
        root_nodes, _ = self._find_root_leaf_nodes()
        rst = {}
        q = list(root_nodes)
        for name in root_nodes:
            if self.nodes[name].val is None and name not in keywords:
                raise ValueError("Input node {} is not provided for the forward pass.".format(name))
            if name in keywords:
                node = keywords[name]
                if isinstance(node, ValNode):
                    self.nodes[name].val = node.val
                else:
                    self.nodes[name].val = node
        graph = self.graph
        while len(q):
            cur = q[0]
            ops = list(graph[cur][1])
            if len(ops) == 0:
                rst[cur] = self.nodes[cur]
                q.remove(cur)
            else:
                for op in ops:
                    process_node = True
                    for inp in graph[op][0]:
                        if inp not in q:
                            process_node = False
                            break
                    if process_node:
                        if len(graph[op][1]) != 1:
                            raise ValueError("A computational node must have 1 output: {}".format(op))
                        out = next(iter(graph[op][1]))
                        inputs = [self.nodes[inp] for inp in graph[op][0]]
                        self.nodes[out].val = self.nodes[op].forward(*inputs)
                        for inp in graph[op][0]:
                            graph[inp][1].remove(op)
                            if len(graph[inp][1]) == 0 and inp in q:
                                q.remove(inp)
                        graph[out][0].remove(op)
                        q.append(out)
        return rst

节点的前向计算

由于我们在每个Node中只记录了其输入和输出Nodename,我们还需要实现一个运行于Node外部的通用性函数来执行前向计算。

def register_op(a: ValNode, b: ValNode, op: CompNode, out_name: str, graph: Graph):
    if a.name not in graph.nodes:
        graph.add_node(a)
    if b.name not in graph.nodes:
        graph.add_node(b)
    if op.name not in graph.nodes:
        graph.add_node(op)
    graph.add_edge(a.name, op.name, False)
    graph.add_edge(b.name, op.name, False)

    rst = op.forward(a, b)
    c = ValNode(val=rst, name=out_name)

    graph.add_node(c)
    graph.add_edge(op.name, c.name, False)

    return c

在这个函数中,我们会把原本不在图中的节点添加到计算图中,然后调用计算节点的forward方法执行前向计算,然后新建一个数值节点来保存计算结果并加入到图当中。我们这里假设所有的计算节点只接受两个输入值。

节点的反向计算

同样地,我们定义一个外部函数来执行计算节点的自动微分。

def register_grad(loss: ValNode, graph: Graph):
    loss.grad = 1
    q = deque()
    q.append(loss)
    while len(q):
        cur = q[0]
        if cur.node_type == NodeType.COMP:
            # For comp node, update its input nodes' grads
            if len(cur.outputs) != 1:
                raise ValueError("A computational node must have 1 output: {}".format(cur.name))
            output = graph.nodes[cur.outputs[0]]
            if output.grad is None:
                tmp = q.popleft()
                if len(q) == 0:
                    raise ValueError("One of the incoming grad is None for Node: {}".format(cur.name))
                q.append(tmp)
                continue
            cur = q.popleft()
            inputs = [graph.nodes[in_name] for in_name in cur.inputs]
            grads = cur.backward(delta=output.grad, inputs=inputs)
            if len(grads) != len(inputs):
                raise ValueError("len(grads) != len(inputs) for Node: {}".format(cur.name))
            for inp, grad in zip(cur.inputs, grads):
                if graph.nodes[inp].grad is None:
                    graph.nodes[inp].grad = grad
                else:
                    graph.nodes[inp].grad += grad
                q.append(graph.nodes[inp])
        elif cur.node_type == NodeType.VAL:
            # For val node, push its input nodes into q
            if cur.grad is None:
                tmp = q.popleft()
                if len(q) == 0:
                    raise ValueError("The output nodes have not calculated grad for Node: {}".format(cur.name))
                q.append(tmp)
                continue
            cur = q.popleft()
            for inp in cur.inputs:
                q.append(graph.nodes[inp])
        else:
            raise ValueError("Unrecognized `node_type` in: {}".format(cur.name))
    return

优化器的实现

我们仿照PyTorch的思路,另外构造一个优化器,接受图节点为参数。图自身的反向计算只是算出梯度值,而优化器负责根据梯度值来更新图里面所有节点的值。

class Optimizer(object):
    def __init__(self, graph: Graph):
        self.graph = graph

    def step(self):
        raise NotImplementedError("`step` method is not implemented.")

    def zero_grad(self):
        for name, node in self.graph.nodes.items():
            if node.node_type == NodeType.VAL:
                node.grad = 0


class SGDOptimizer(Optimizer):
    def __init__(self, lr, graph: Graph):
        super().__init__(graph=graph)
        self.lr = lr

    def step(self):
        for name, node in self.graph.nodes.items():
            if node.node_type == NodeType.VAL and node.train:
                if node.grad is None:
                    raise ValueError("Grad has not been calculated for Node: {}".format(node.name))
                node.val -= self.lr * node.grad

与PyTorch类似,我们这里的zero_grad也是必不可少的,避免每次梯度的计算受到前一次迭代梯度的影响。

计算图实例

计算图相关的代码准备好以后,我们用几个例子验证一下计算图的功能。

DAG的构造

我们按照上面的线性模型检查一下能否成功初始化一个DAG。

edges = [("x", "multiply"), ("w", "multiply"), ("multiply", "plus"), ("b", "plus"), ("plus", "y")]
graph = Graph(edges=edges)

print("Every node in graph")
for name, node in graph.nodes.items():
    print("{}: {}".format(name, node))

print("\nRoot nodes")
print(graph.root_nodes)
print("\nLeaf nodes")
print(graph.leaf_nodes)

结果如下图所示:

 

环的检测

我们把上面的edges改为一个带环的结构:

edges = [("a", "b"), ("e", "b"), ("b", "c"), ("f", "c"), ("c", "d"), ("c", "g"), ("d", "b")]

执行时会提示存在环:

把最后一条边("d", "b")删除后再执行就不会报错了。

动态图执行

接下来我们用动态图的方式为计算图添加新节点,并支持在构造的过程中随时打印节点的值。

graph = Graph()

print("\nDynamic graph")
x = ValNode(2, name="x", train=False)
w = ValNode(1, name="w")
x1 = register_op(x, w, MulNode("mul"), "x1", graph)
print("x1: {}".format(x1))
b = ValNode(1, name="b")
y = register_op(x1, b, AddNode("add"), "y", graph)
label = ValNode(1, name="label", train=False)
loss = register_op(y, label, L2Node("l2"), "loss", graph)
print("loss: {}".format(loss))

print("\nEvery node in graph")
for name, node in graph.nodes.items():
    print("{}: {}".format(name, node))

我们这里的register_op如果用运算符重载的方式来实现,就能够达到PyTorch中执行x1=w*x来创建一个乘法计算节点和一个结果节点的效果。

梯度的计算也类似于PyTorch的过程,只不过我们用一个显示的函数调用(register_grad)来模仿PyTorch中loss.backward()的效果:

print("\nBack-propagation")
register_grad(loss, graph)

optim = SGDOptimizer(lr=0.1, graph=graph)
optim.step()

print("\nUpdated parameters")
print("w: {}".format(graph.nodes["w"]))
print("b: {}".format(graph.nodes["b"]))

结果如下图所示:

可见,两个参数经过一次迭代已经得到了更新。我们可以手动计算对结果进行验证:

b的验证过程类似,就不写出来了。

静态图执行

我们用“静态图”接口来对这个线性模型进行迭代训练。

print("\nStatic graph")
iter_num = 20
for itr in range(iter_num):
    data = random.uniform(-5, 5)
    target = -data + 3
    rst = graph.forward({"x": data, "label": target})

    optim.zero_grad()
    register_grad(loss, graph)
    optim.step()

    print("Iter {}: loss {:.4f}, w {:.4f}, b {:.4f}".format(itr, rst["loss"].val, graph.nodes["w"].val, graph.nodes["b"].val))

结果如下图所示:

可见,训练得到的参数wb已经十分接近我们的采样函数(y=-x+3)了。由于训练过程本身有随机性,所以每次训练的结果都会不同。但总体来说,较小的步长+更多的迭代次数,或者较大的采样区间(比如random.uniform(-5, 5)相对于random.uniform(-3, 3))可以提高收敛效果。


我们已经用Python验证了实现计算图的大体思路,接下来就可以开始正式的C++开发工作了。预计下一篇文章会大致给出C++项目的结构,另外可能会给出底层Tensor的初步实现。

把这篇文章分享给你的朋友:
Subscribe
订阅评论
guest
0 评论
Inline Feedbacks
View all comments