以深度学习为例理解22种设计模式(二)结构型模式

本系列包括3篇文章,以深度学习的模型搭建和训练过程为例,解释面向对象编程中22种设计模式的基本原理,并给出C++实现。

这些设计模式的实现方法大多参考自《深入设计模式》。为了用尽可能少的代码体现这些设计模式的过程,在代码中直接对指针进行操作,并没有使用智能指针。

该篇文章介绍结构型模式,包括适配器、桥接、组合、装饰、外观、享元以及代理。

本文的所有代码都在我的GitHub上:github.com/johnhany/design_patterns


创建型模式

请见《以深度学习为例理解22种设计模式(一)创建型模式》


结构型模式

结构型模式(Structural Pattern)主要关心如何设计类的接口,以提高类结构的灵活性。

适配器

适配器(Adapter)模式适用于使现有代码调用不兼容的类,比如遗留代码、第三方库等。

每当需要在祖传老代码上增加新功能,或者在不修改对方源码的前提下调用第三方库时,适配器模式是十分有效的。比如,我们已经实现了一个读取PyTorch模型的工具,现在需要用这个工具来读取一个用TensorFlow训练得到的模型。我们知道TensorFlow默认的数据格式是“N, H, W, C”,而PyTorch默认的数据格式是“N, C, H, W”。如果直接用现有工具调用TensorFlow的模型文件,整个网络的参数都要乱套了。所以,在不更改模型文件和现有工具的前提下,我们就需要设计一个适配器来实现数据格式的转换。

首先,我们定义一个PyTorch类来模拟产生PyTorch格式的模型:

class PyTorch {
public:
    virtual ~PyTorch() = default;

    virtual string getParameters() const {
        return "NCHW";
    }
};

然后,定义一个独立的ModelLoader类来读取这个模型:

class ModelLoader {
public:
    virtual ~ModelLoader() = default;

    void loadTrainedModel(const string& param) const {
        if (param == "NCHW")
            cout << "I can run this model correctly." << endl;
        else
            cout << "I cannot recognize this model." << endl;
    }
};

我们在ModelLoader::loadTrainedModel()当中加了一个限制,即只识别PyTorch格式的模型。

我们再定义一个TensorFlow类来产生TensorFlow格式的模型:

class TensorFlow {
public:
    virtual ~TensorFlow() = default;

    virtual string getTFParameters() const {
        return "NHWC";
    }
};

需要注意的是,我们的TensorFlow类产生模型的接口名称(getTFParameters)与PyTorch也是不一样的(getParameters)。这说明如果不借助适配器的话,我们每一处读取模型的代码都要针对TensorFlow的接口做特殊处理。这样无疑会增加现有代码的复杂程度。

模型适配器的定义如下:

class Adapter : public PyTorch {
private:
    TensorFlow* tensorflow;
public:
    Adapter(TensorFlow* tf) : tensorflow(tf) {}

    string getParameters() const override {
        string param = this->tensorflow->getTFParameters();
        std::swap(param[2], param[3]);
        std::swap(param[1], param[2]);
        return param;
    }
};

我们在与PyTorch名称相同的接口(getParameters)内部实现数据格式的转换。

于是,我们现有的代码只需要按照PyTorch的接口来读取模型就可以了:

void load_model(const PyTorch* model) {
    string param = model->getParameters();
    ModelLoader* loader = new ModelLoader();
    loader->loadTrainedModel(param);
}

最后,给出main()函数:

int main() {
    cout << "Loading a PyTorch model..." << endl;
    PyTorch* pytorch = new PyTorch;
    load_model(pytorch);

    cout << "Loading a TensorFlow model..." << endl;
    TensorFlow* tensorflow = new TensorFlow();
    //load_model(TensorFlow);   // won't compile

    cout << "Loading a TensorFlow model using Adapter..." << endl;
    Adapter* adapter = new Adapter(tensorflow);
    load_model(adapter);

    delete pytorch;
    delete tensorflow;
    delete adapter;

    return 0;
}

其中,load_model(TensorFlow);那一行编译时会报错,以模拟现有代码在不做更改的前提下,无法直接处理TensorFlow模型的情景。但是通过适配器,我们就可以正常读取TensorFlow模型了。

运行效果如下:

Loading a PyTorch model...
I can run this model correctly.
Loading a TensorFlow model...
Loading a TensorFlow model using Adapter...
I can run this model correctly.

桥接

桥接(Bridge)模式适用于需要按照若干维度的多个属性扩展现有类,或者在运行时切换不同实现方式的情景。典型的桥接模式把原问题拆分为“抽象”和“实现”两个结构,“实现”负责提供底层的、具体的单元操作(比如跨平台的底层驱动),“抽象”负责在实现的基础之上执行高阶的操作(比如程序的GUI层)。桥接通过不同“实现”和“抽象”的组合实现更丰富的行为。

比如,我们现在想做许多组CNN的图像分类实验。可用的数据集有Cifar10和ImageNet。模型除了一个基本的CNN之外,我们还想试一下网络结构搜索(NAS)的效果,可选的方法有强化学习(RL)和遗传策略(ES)。于是,我们现在需要做的实验有6组:Cifar10+CNN,ImageNet+CNN,Cifar10+RL,ImageNet+RL,Cifar10+ES,ImageNet+ES。如果每一组实验都用一个子类来描述的话,代码量还是比较大的。尤其是,如果未来需要增加新的数据集或模型时,需要增加的子类数量是呈倍数增长的。

如果用桥接模式来设计类的结构就简单多了。因为我们要做的6组实验可以按照两个维度分割开:数据集(Cifar10, ImageNet)和模型(CNN, RL, ES)。我们把数据集视为“实现”层,即我们的实验需要具体操作的数据对象;把模型视为“抽象”层,即需要用模型调用数据集来达到完成一组实验的目的。这样,我们只需要提供2个数据集子类和3个模型子类,就能够进行6组实验了。

我们先定义一个数据集的基类,并继承出两个子类:

class DatasetImplementation {
public:
    virtual ~DatasetImplementation() {}

    virtual string dataPreprocess() const = 0;
};

class Cifar10 : public DatasetImplementation {
public:
    string dataPreprocess() const override {
        return "Cifar10";
    }
};

class ImageNet : public DatasetImplementation {
public:
    string dataPreprocess() const override {
        return "ImageNet";
    }
};

然后再定义一个模型的基类,并继承出两个NAS子类:

class ModelAbstraction {
protected:
    DatasetImplementation* dataset;
public:
    ModelAbstraction(DatasetImplementation* impl) : dataset(impl) {}
    virtual ~ModelAbstraction() {}

    virtual void train() const {
        cout << "Train base model on " + this->dataset->dataPreprocess() << endl;
    }
};

class RLSearchModel : public ModelAbstraction {
public:
    RLSearchModel(DatasetImplementation* impl) : ModelAbstraction(impl) {}

    void train() const override {
        cout << "Apply Reinforcement Learning search on " + this->dataset->dataPreprocess() << endl;
    }
};

class ESSearchModel : public ModelAbstraction {
public:
    ESSearchModel(DatasetImplementation* impl) : ModelAbstraction(impl) {}

    void train() const override {
        cout << "Apply Evolution Strategy search on " + this->dataset->dataPreprocess() << endl;
    }
};

这里为了简便,CNN的训练直接由基类完成,而两种NAS的训练则由子类完成。

于是,执行实验的函数就非常简洁了:

void train_api(const ModelAbstraction& model) {
    model.train();
}

最后,我们在main()里面依次执行每组实验:

int main() {
    DatasetImplementation* dataset = new Cifar10();
    ModelAbstraction* model = new ModelAbstraction(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    dataset = new ImageNet();
    model = new ModelAbstraction(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    dataset = new Cifar10();
    model = new RLSearchModel(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    dataset = new ImageNet();
    model = new RLSearchModel(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    dataset = new Cifar10();
    model = new ESSearchModel(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    dataset = new ImageNet();
    model = new ESSearchModel(dataset);
    train_api(*model);
    delete dataset;
    delete model;

    return 0;
}

运行结果如下:

Train base model on Cifar10
Train base model on ImageNet
Apply Reinforcement Learning search on Cifar10
Apply Reinforcement Learning search on ImageNet
Apply Evolution Strategy search on Cifar10
Apply Evolution Strategy search on ImageNet

组合

组合(Composite)模式适用于需要以树状结构表示对象之间关系的情景。树由两种元素构成:简单叶子元素和容器元素,其中容器元素必须能够包含叶子元素和其他容器元素。所有元素共用相同的接口,容器的接口只需要关心如何把其元素返回的结果合并起来,具体的接口细节由叶子元素实现。

比如,我们现在已经训练好了一个弱分类器,但是单独用弱分类器来处理实际问题的效果总是不太好。于是,我们想利用现有的弱分类器实现一个集成分类器,甚至通过级联多个集成分类器的方式来构造一个庞大的分类器。这个大分类器的构造就可以用组合模式来设计。

首先,我们定义一个分类器基类,来声明所有的接口:

class Classifier {
protected:
    Classifier* parent;
public:
    virtual ~Classifier() { cout << "Classifier destroyed" << endl; }

    void setParent(Classifier* parent) {
        this->parent = parent;
    }
    Classifier* getParent() const {
        return this->parent;
    }

    virtual void add(Classifier* child) {}
    virtual void remove(Classifier* child) {}

    virtual bool isEnsemble() const {
        return false;
    }
    virtual string predict() const = 0;
};

然后定义弱分类器,即树结构的叶子节点:

std::default_random_engine generator;
std::uniform_int_distribution<int> distribution(0, 1);

class LeafClassifier : public Classifier {
public:
    string predict() const override {
        int dice_roll = distribution(generator);
        return to_string(dice_roll);
    }
};

因为在组合模式中只有叶子节点会进行具体的操作,所以弱分类器会放在叶子上,而且只有叶子会给出原始的预测值。

接着定义集成分类器,即树结构的容器节点:

class EnsembleClassifier : public Classifier {
protected:
    list<Classifier*> children;
public:
    ~EnsembleClassifier() override {
        for (auto& i : children)
            delete i;
        cout << "  Ensemble destroyed" << endl;
    }
    void add(Classifier* child) override {
        this->children.push_back(child);
        child->setParent(this);
    }
    void remove(Classifier* child) override {
        this->children.remove(child);
        child->setParent(nullptr);
    }
    bool isEnsemble() const override {
        return true;
    }
    string predict() const override {
        string rst = "[";
        for (auto& i : children) {
            rst += i->predict();
        }
        rst += "]";
        return rst;
    }
};

容器节点只负责把被其包含的元素所给出的结果进行汇总,然后给出最终的结果。

于是,预测函数只接受一个Classifier指针,而不需要关心这个指针指向的模型到底具有什么样的结构:

void predict_api(Classifier* classifier) {
    cout << "Predict result: " << classifier->predict() << endl;
}

最后给出main()函数:

int main() {
    cout << "Using a single classifier" << endl;
    Classifier* single = new LeafClassifier();

    predict_api(single);
    delete single;

    cout << "Using an ensemble classifier" << endl;
    Classifier* ensemble = new EnsembleClassifier();
    ensemble->add(new LeafClassifier());
    ensemble->add(new LeafClassifier());
    ensemble->add(new LeafClassifier());

    predict_api(ensemble);
    delete ensemble;

    cout << "Using a stacked-ensemble classifier" << endl;
    Classifier* root = new EnsembleClassifier();
    Classifier* sub1 = new EnsembleClassifier();
    root->add(sub1);
    sub1->add(new LeafClassifier());
    Classifier* sub2 = new EnsembleClassifier();
    root->add(sub2);
    sub2->add(new LeafClassifier());
    sub2->add(new LeafClassifier());

    predict_api(root);
    delete root;

    return 0;
}

运行结果如下:

Using a single classifier
Predict result: 0
Classifier destroyed
Using an ensemble classifier
Predict result: [010]
Classifier destroyed
Classifier destroyed
Classifier destroyed
  Ensemble destroyed
Classifier destroyed
Using a stacked-ensemble classifier
Predict result: [[1][00]]
Classifier destroyed
  Ensemble destroyed
Classifier destroyed
Classifier destroyed
Classifier destroyed
  Ensemble destroyed
Classifier destroyed
  Ensemble destroyed
Classifier destroyed

装饰

装饰(Decorator)模式适用于在不修改原始接口的前提下增加新功能的情景,尤其是原始接口被声明为final时。

比如,我们辛辛苦苦实现了一个模型,为了避免使用我们代码的人通过继承的方式随意更改代码的行为,于是把公共的接口声明为final,来传达一个“我的接口已经能够满足你们需求了”的信息。但是,我们后来发现有一些特殊的需求没有考虑到,比如有时候我们需要把在ImageNet下训练好的模型的前几层参数固定,然后针对新数据集只训练最后一层,来快速地把已有的分类模型运用在新问题上。而新问题可以是分类问题,也可以是目标检测问题。如果是目标检测问题的话,我们还要在网络后面再加一些层来进行微调训练。但是我们又不想改变已经写好的final接口,这时装饰模式就派上用场了。

首先,定义一个模型基类:

class Network {
protected:
    vector<bool> layers;
public:
    Network() {}
    Network(int depth) {
        while (depth--)
            layers.push_back(true);
    }
    virtual ~Network() { cout << "Network destroyed" << endl; }
    virtual string train() const = 0;
};

然后继承出一个可以训练的模型,并给出train()接口的实现:

class TrainableNetwork : public Network {
public:
    TrainableNetwork(int depth) : Network(depth) {}

    string train() const final {
        string rst = "";
        for (bool i : layers)
            rst += to_string(i);
        return rst;
    }
};

请注意,这里的train()被声明为final,所以我们无法用继承的方式改变train()的行为。比如,下面的代码是无法编译的:

class FineTunedNetwork : public TrainableNetwork {
public:
    string train() const {}
};

这时,就需要定义一个装饰器:

class Decorator : public Network {
protected:
    Network* network;
public:
    Decorator(Network* net) : network(net) {}

    string train() const override {
        return this->network->train();
    }
};

装饰器可以继承出子类:

class FineTuned : public Decorator {
public:
    FineTuned(Network* net): Decorator(net) {}

    string train() const override {
        string base = this->network->train();
        int depth = base.size();
        for (int i = 0; i < depth-1; i++)
            base[i] = '0';
        base[depth-1] = '1';
        return base;
    }
};

我们这里指定最后一层是可训练的(值为1),而前几层都是不可训练的(值为0)。

装饰器的子类还可以继续继承:

class ObjectDetection : public FineTuned {
public:
    ObjectDetection(Network* net): FineTuned(net) {}

    string train() const override {
        string base = this->network->train();
        base.push_back('1');
        base.push_back('1');
        return base;
    }
};

我们这里在前面微调模型的基础上,增加了两层可训练的参数(值为1)。

模型训练的函数不需要做特别的处理也能调用这几种模型:

void train_api(Network* network) {
    cout << "Layers that will be trained: " << network->train() << endl;
}

最后是main()函数:

int main() {
    cout << "Train a base model" << endl;
    Network* base_model = new TrainableNetwork(3);
    train_api(base_model);

    cout << "Fine-tune the base model" << endl;
    Network* fine_tuned = new FineTuned(base_model);
    train_api(fine_tuned);

    cout << "Train an object detection model based on fine-tuned" << endl;
    Network* obj_det = new ObjectDetection(fine_tuned);
    train_api(obj_det);

    delete base_model;
    delete fine_tuned;
    delete obj_det;

    return 0;
}

运行结果如下:

Train a base model
Layers that will be trained: 111
Fine-tune the base model
Layers that will be trained: 001
Train an object detection model based on fine-tuned
Layers that will be trained: 00111
Network destroyed
Network destroyed
Network destroyed

外观

外观(Facade)模式适用于需要用一个简洁的接口来协调多个类的复杂行为的情景。其实许多人很可能在不了解外观模式的情况下就已经在用了,因为外观模式的含义就是在一个类的接口内调用很多个其他类的接口来完成一个复杂的任务。

比如,在搭建深度学习工具时,我们希望用一个train()接口来执行数据集的初始化、模型的初始化、模型的迭代训练和模型的测试这一系列动作。这样,我们就可以在外部函数直接调用一次train()就能完成整个模型的训练了。

我们首先定义数据集类和模型类:

class Dataset {
public:
    void init() const {
        cout << "Dataset initialized" << endl;
    }

    string getBatch() const {
        return "Here's your data mini-batch";
    }
};

class Network {
public:
    void init() const {
        cout << "Network initialized" << endl;
    }

    string optimize() const {
        return "Network optimized for 1 step";
    }

    string predict() const {
        return "Network predicts some results";
    }
};

然后定义一个外观类,负责调用数据集和模型,使得通过一个简单的接口就能完成整个模型的训练过程:

class TrainerFacade {
protected:
    Dataset* dataset;
    Network* network;
public:
    TrainerFacade(Dataset* data = nullptr, Network* net = nullptr) {
        this->dataset = data ?: new Dataset();
        this->network = net ?: new Network();
    }
    ~TrainerFacade() {
        delete dataset;
        delete network;
    }

    void train() {
        dataset->init();
        network->init();

        for (int i = 0; i < 3; i++) {
            cout << "Iteration " << i << endl;
            string batch = dataset->getBatch();
            cout << batch << endl;
            string logs = network->optimize();
            cout << logs << endl;
        }

        string predict = network->predict();
        cout << predict << endl;
    }
};

这样,我们只需要调用一次train()就能训练模型了:

void train_api(TrainerFacade* trainer) {
    trainer->train();
}

最后是main()函数:

int main() {
    Dataset* dataset = new Dataset();
    Network* network = new Network();

    TrainerFacade* trainer = new TrainerFacade(dataset, network);
    train_api(trainer);

    delete trainer;
    return 0;
}

运行结果如下所示:

Dataset initialized
Network initialized
Iteration 0
Here's your data mini-batch
Network optimized for 1 step
Iteration 1
Here's your data mini-batch
Network optimized for 1 step
Iteration 2
Here's your data mini-batch
Network optimized for 1 step
Network predicts some results

在使用外观模式时要注意避免让外观类和很多个类过于耦合,造成后期代码修改和维护上的困难。

享元

享元(Flyweight)模式适用于需要在有限内存保留大量相似对象的情景,尤其是对象之间存在重复的属性时。

享元的一个典型的应用场景是游戏开发。比如游戏场景里需要很多汽车,相同外形的轿车可能会有不同的颜色,而每一辆车又有各自不同的位置坐标。如果在每一辆车的实例中都完整地保存其模型信息、颜色信息和位置信息的话,越多的车辆势必消耗更多的内存空间。而且由于每辆轿车所记录的模型信息是完全相同的(相同的形状),说明有很多内存被重复的数据平白无故地占用了。这时,我们可以用一个共享的实例来保存模型信息,然后每辆车只需要保存各自的位置信息,然后引用共享的模型实例(或者在二者中间再加一层保存颜色信息的实例,就可以允许很多辆车共享同一种颜色)。这样就可以节约非常可观的内存空间了。

在深度学习中,我们可以设想这样一种场景:我们已经训练好了一个通用性很强的CNN,而且在这个CNN的基础之上设计了几种用来解决其他问题的模型(比如在前面装饰模式中的那样的微调模型和目标检测模型)。现在我们想把这几种模型同时运用在一个复杂系统里,但是机器的内存不足以同时容纳这几个模型。这时,我们可以用享元模式来设计模型的结构,即让几个模型共享相同的参数集合,以减小内存的消耗。

首先,我们定义两个结构体,用来表示可以共享的CNN隐层和在各模型内都不相同的微调层:

struct BackboneNetwork {
    vector<string> layers;

    BackboneNetwork(std::initializer_list<string> names) {
        for (const string& s : names)
            layers.push_back(s);
    }

    string train() const {
        int n = layers.size();
        string output = "";
        for (int i = 0; i < n; i++) {
            output += layers[i];
            if (i != n-1)
                output += "+";
        }
        return output;
    }
};


struct FineTunePart {
    string linear_layer;

    FineTunePart(string name) : linear_layer(name) {}
};

然后定义一个目标检测的模型类:

class ObjectDetection {
private:
    BackboneNetwork* backbone;
public:
    ObjectDetection(const BackboneNetwork* network) : backbone(new BackboneNetwork(*network)) {}
    ObjectDetection(const ObjectDetection& other) : backbone(new BackboneNetwork(*other.backbone)) {}
    ~ObjectDetection() {
        delete backbone;
    }
    BackboneNetwork* getBackbone() const {
        return backbone;
    }
    string train(const FineTunePart& fine_tune) const {
        return this->backbone->train() + "+" + fine_tune.linear_layer;
    }
};

这里,我们希望不同的目标检测模型都指向同一个BackboneNetwork,但各自有其唯一的FineTunePart。这两部分合起来,用来完成目标检测任务。

另外,我们还需要定义一个FlyweightFactory,负责维护被共享的BackboneNetwork

class FlyweightFactory {
private:
    unordered_map<string, ObjectDetection> flyweights;
    string getKey(const BackboneNetwork& backbone) const {
        return backbone.train();
    }
public:
    FlyweightFactory(std::initializer_list<BackboneNetwork> backbones) {
        for (const BackboneNetwork& bb : backbones)
            this->flyweights.insert(make_pair(this->getKey(bb), ObjectDetection(&bb)));
    }

    ObjectDetection getFlyweight(const BackboneNetwork& backbone) {
        string key = this->getKey(backbone);
        if (this->flyweights.find(key) == this->flyweights.end()) {
            this->flyweights.insert(make_pair(key, ObjectDetection(&backbone)));
        }
        return this->flyweights.at(key);
    }

    void printAllFlyweights() const {
        cout << "All Flyweights:" << endl;
        for (auto& itm : this->flyweights)
            cout << itm.first << endl;
    }
};

于是,我们就可以用下面的方式调用整个网络:

void train_api(FlyweightFactory& factory, std::initializer_list<string> base_layers, string final_layer) {
    const ObjectDetection& obj_det = factory.getFlyweight(base_layers);
    string rst = obj_det.train({final_layer});
    cout << "Fine-tuning a network:" << endl;
    cout << rst << endl;
}

最后给出main()函数:

int main()
{
    FlyweightFactory *factory = new FlyweightFactory({{"A1", "A2", "A3"}, {"B1", "B2", "B3"}});
    factory->printAllFlyweights();

    train_api(*factory, {"A1", "A2", "A3"}, "X1");
    train_api(*factory, {"A1", "A2", "A3"}, "Y1");
    train_api(*factory, {"B1", "B2", "B3"}, "X1");
    train_api(*factory, {"C1", "C2", "C3"}, "X1");

    factory->printAllFlyweights();

    delete factory;

    return 0;
}

FlyweightFactory内部保证每种BackboneNetwork只会被保存一份。代码运行的结果如下:

All Flyweights:
B1+B2+B3
A1+A2+A3
Fine-tuning a network:
A1+A2+A3+X1
Fine-tuning a network:
A1+A2+A3+Y1
Fine-tuning a network:
B1+B2+B3+X1
Fine-tuning a network:
C1+C2+C3+X1
All Flyweights:
C1+C2+C3
A1+A2+A3
B1+B2+B3

代理

代理(Proxy)模式适用于允许延迟调用一个十分耗费资源的复杂对象的情景。代理可以在不修改服务对象的前提下,增加缓存、日志等功能。

比如,我们的实验室突然有钱了,可以购买几十台带GPU的机器,组成一个大规模的并行计算网络。此后在训练新模型时,可以设置一个非常大的batch size,并把一个batch的样本均匀地分配给每台机器。在每台机器完成一次迭代训练之后,把各自的梯度回传给一台主机。然后在这台主机内把所有的梯度合并起来,就能够知道应该如何更新网络的参数了。但是实际运行时,我们发现连接每台机器的网络带宽不足以支撑如此大的数据吞吐量,使得这个并行计算平台并不能发挥其威力。于是,我们设计了新的并行训练策略,就是让每台机器先各自保留几次迭代的梯度之后再回传给主机,而且不同的从机的回传时间也要错开,这样就减小了整体的带宽消耗。这时,我们可以利用代理模式来管理每台机器的梯度回传行为。具体地说,就是每台机器仍然会在每次迭代训练结束时把梯度传给代理,但传给代理的梯度究竟会在未来的什么时刻交给主机是由代理决定的。

首先,我们定义一个计算节点的基类:

class DeviceNode {
public:
    virtual void pushGrad(int grad) = 0;
};

然后继承出表示真实机器的节点:

class RealDeviceNode : public DeviceNode {
public:
    void pushGrad(int grad) override {
        cout << "Pushing grad value " + to_string(grad) + " from a real device node" << endl;
    }
};

每调用一次pushGrad()就表示梯度从真机传给了主机,这个过程是消耗带宽资源的。

我们还需要设计一个代理类:

class ProxyDeviceNode : public DeviceNode {
private:
    RealDeviceNode* real_device;
    int accum_grad = 0;

    bool isItTimeToPush() const {
        if (accum_grad < 3) {
            cout << "  Not yet..." << endl;
            return false;
        } else {
            cout << "  Okay, it's time" << endl;
            return true;
        }
    }

    void printLog() const {
        cout << "  Printing some logs" << endl;
    }
public:
    ProxyDeviceNode(RealDeviceNode* device) : real_device(new RealDeviceNode(*device)) {}
    ~ProxyDeviceNode() {
        delete real_device;
    }

    void pushGrad(int grad) override {
        accum_grad += grad;
        if (this->isItTimeToPush()) {
            this->real_device->pushGrad(accum_grad);
            accum_grad = 0;
        }
        this->printLog();
    }
};

代理节点提供了缓存功能和日志打印功能,它的pushGrad()会按照提前给定的策略决定是把梯度放入缓存还是调用真机的接口传给主机。

我们模拟一个需要6次迭代的训练过程:

void train_api(DeviceNode& device) {
    for (int i = 0; i < 6; i++) {
        device.pushGrad(1);
    }
}

最后给出main()函数:

int main() {
    cout << "Executing grad push with a real device:" << endl;
    RealDeviceNode* real_device = new RealDeviceNode();
    train_api(*real_device);

    cout << "Executing grad push via a proxy device:" << endl;
    ProxyDeviceNode* proxy_device = new ProxyDeviceNode(real_device);
    train_api(*proxy_device);

    delete real_device;
    delete proxy_device;
    return 0;
}

运行结果如下:

Executing grad push with a real device:
Pushing grad value 1 from a real device node
Pushing grad value 1 from a real device node
Pushing grad value 1 from a real device node
Pushing grad value 1 from a real device node
Pushing grad value 1 from a real device node
Pushing grad value 1 from a real device node
Executing grad push via a proxy device:
  Not yet...
  Printing some logs
  Not yet...
  Printing some logs
  Okay, it's time
Pushing grad value 3 from a real device node
  Printing some logs
  Not yet...
  Printing some logs
  Not yet...
  Printing some logs
  Okay, it's time
Pushing grad value 3 from a real device node
  Printing some logs

可见,通过直接调用真机接口的方式,6次迭代回传了6次梯度;而通过代理的缓存方式,6次迭代只回传了2次梯度。这样我们就把带宽的占用减少了2/3。


行为模式

请见《以深度学习为例理解22种设计模式(三)行为模式》


本文的所有代码都在我的GitHub上:github.com/johnhany/design_patterns

把这篇文章分享给你的朋友:
Subscribe
订阅评论
guest
2 评论
最新
最旧 得票最多
Inline Feedbacks
View all comments
trackback
6 月 之前

[…] 请见《以深度学习为例理解22种设计模式(二)结构型模式》。 […]

trackback
6 月 之前

[…] 请见《以深度学习为例理解22种设计模式(二)结构型模式》。 […]