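Understanding 22 Design Patterns Through Deep Learning (III): Behavioral Patterns

This series of three articles uses the building and training of deep learning models as a running example to explain the basic principles of 22 object-oriented design patterns, with a C++ implementation of each.

Most of these implementations are adapted from Dive Into Design Patterns. To show each pattern with as little code as possible, the code manipulates raw pointers directly rather than using smart pointers.

This article covers the behavioral patterns: Chain of Responsibility, Command, Iterator, Mediator, Memento, Observer, State, Strategy, Template Method, and Visitor.

All of the code in this article is available on my GitHub: github.com/johnhany/design_patterns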


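Creational Patterns

See "Understanding 22 Design Patterns Through Deep Learning (I): Creational Patterns".


Structural Patterns

See "Understanding 22 Design Patterns Through Deep Learning (II): Structural Patterns".


Behavioral Patterns

Behavioral patterns are mainly concerned with how objects communicate and how responsibilities are assigned among them.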

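Chain of Responsibility

The Chain of Responsibility pattern fits situations where a request can be handled in different ways, but the order of the handlers is not known in advance or must be adjustable at run time.

For example, suppose we have trained several detection models: face detection, body detection, and text detection. We now want to run them on a wide variety of real photos. A photo may contain a person, a complete face, or some text, but we cannot assume that every photo contains all three kinds of targets. Moreover, if a face is present, we may also want to call another model to estimate the person's emotion, age, and so on.

We could cover every combination of present and absent targets with cumbersome if/switch branches. Alternatively, we can use the Chain of Responsibility pattern with a uniform interface. The chain first runs face detection together with any work that depends on a face, then body detection, and finally text detection. This reduces the effort of adding new features later.

First, we define the interface of each node in the chain along with a default implementation:

#include <iostream>
#include <string>
using namespace std;

class Handler {
public:
    virtual ~Handler() {}
    virtual Handler* setNext(Handler* handler) = 0;
    virtual string process(string img) = 0;
};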

class AbstractHandler : public Handler {
private:
    Handler* next;
public:
    AbstractHandler() : next(nullptr) {}

    Handler* setNext(Handler* handler) override {
        this->next = handler;
        return handler;
    }
    string process(string img) override {
        if (this->next)
            return this->next->process(img);
        else
            return {};
    }
};

Now we can define nodes for the three detection models:

class FaceDetection : public AbstractHandler {
public:
    string process(string img) override {
        string rst = "";
        if (img.find("face") != string::npos)
            rst += "Detected face. ";
        rst += AbstractHandler::process(img);
        return rst;
    }
};

class BodyDetection : public AbstractHandler {
public:
    string process(string img) override {
        string rst = "";
        if (img.find("body") != string::npos)
            rst += "Detected body. ";
        rst += AbstractHandler::process(img);
        return rst;
    }
};

class TextDetection : public AbstractHandler {
public:
    string process(string img) override {
        string rst = "";
        if (img.find("text") != string::npos)
            rst += "Detected text. ";
        rst += AbstractHandler::process(img);
        return rst;
    }
};

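The chain itself is structured like a linked list. Each node does its own work and then calls the next node, so the request travels down the chain recursively.

So our detection entry point looks like this: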

void detect_api(Handler& handler, string& image) {
    cout << "Processing image: " << image << endl;
    string rst = handler.process(image);
    if (rst.empty())
        rst = "Detected nothing.";
    cout << rst << endl;
}

Finally, here is the main() function:

int main() {
    FaceDetection* face = new FaceDetection();
    BodyDetection* body = new BodyDetection();
    TextDetection* text = new TextDetection();

    face->setNext(body)->setNext(text);

    string img1 = "An image of a person with face and body, and some text";
    detect_api(*face, img1);

    string img2 = "An image of a person with face, and some text";
    detect_api(*body, img2);

    delete face;
    delete body;
    delete text;

    return 0;
}

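The output is shown below. Note that the second image is passed to body rather than face, so processing starts in the middle of the chain. As a result, face detection is skipped even though the image contains a face: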

Processing image: An image of a person with face and body, and some text
Detected face. Detected body. Detected text. 
Processing image: An image of a person with face, and some text
Detected text.
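Each node only holds a pointer to its successor, so the order of the chain is just data and can be decided at run time. That is exactly the situation this pattern targets. The sketch below is an illustration under assumptions, not part of the original project:

- it collapses the three detectors into one hypothetical KeywordHandler class;
- it uses a hypothetical buildChain() helper that links nodes in the order given by a list of names;
- it uses std::unique_ptr for ownership instead of the article's raw pointers.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified node: report one keyword if present, then forward the request.
class KeywordHandler {
public:
    explicit KeywordHandler(std::string keyword) : keyword_(std::move(keyword)) {}
    void setNext(KeywordHandler* next) { next_ = next; }
    std::string process(const std::string& img) const {
        std::string rst;
        if (img.find(keyword_) != std::string::npos)
            rst += "Detected " + keyword_ + ". ";
        if (next_)                      // forward to the successor, if any
            rst += next_->process(img);
        return rst;
    }
private:
    std::string keyword_;
    KeywordHandler* next_ = nullptr;
};

// Hypothetical helper: link the nodes in the order given at run time.
// The returned vector owns the nodes; element 0 is the head of the chain.
std::vector<std::unique_ptr<KeywordHandler>> buildChain(const std::vector<std::string>& order) {
    std::vector<std::unique_ptr<KeywordHandler>> nodes;
    for (const auto& name : order)
        nodes.push_back(std::make_unique<KeywordHandler>(name));
    for (size_t i = 0; i + 1 < nodes.size(); ++i)
        nodes[i]->setNext(nodes[i + 1].get());
    return nodes;
}
```

With the order {"text", "face"}, an image containing both keywords reports text before face. Swapping the list swaps the output without touching any handler class.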

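Command

The Command pattern encapsulates a request as a standalone object, which makes it possible to queue requests, log them, cancel them, and so on.

For example, suppose we are developing an app that adds various special effects to photos. The effects are produced by deep learning models, so each step takes a relatively long time. The app therefore needs a command queue that records all the effects the user wants, and then runs each effect model in turn in the background. We also want a way to revoke commands that have not run yet, to shorten the wait for impatient users. This can be implemented with the Command pattern.

First, we define a Command base class that represents a command to be executed:

#include <iostream>
#include <string>
#include <vector>
using namespace std;

class Command {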
protected:
    string image;
public:
    virtual ~Command() {}
    void setCurrentImage(string& img) {
        this->image = img;
    }
    virtual void execute() const = 0;
};

Next, we define a Receiver class that is responsible for invoking the time-consuming deep learning models:

class Receiver {
public:
    void applyEffectsOnForeground(const string& img) {
        cout << "  Add some special effects on the foreground of " << img << endl;
    }
    void applyEffectsOnBackground(const string& img) {
        cout << "  Add some special effects on the background of " << img << endl;
    }
};

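It provides two functions that apply special effects to the foreground of an image (for example, the person) and to its background, respectively.

Now we can derive two command subclasses. One processes the background of a portrait photo, and the other processes the whole image: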

class EditPortrait : public Command {
private:
    Receiver* receiver;
public:
    explicit EditPortrait(Receiver* rec) : receiver(rec) {}
    void execute() const override {
        cout << "Edit portrait..." << endl;
        this->receiver->applyEffectsOnBackground(this->image);
    }
};

class EditFullImage : public Command {
private:
    Receiver* receiver;
public:
    explicit EditFullImage(Receiver* rec) : receiver(rec) {}
    void execute() const override {
        cout << "Edit full image..." << endl;
        this->receiver->applyEffectsOnForeground(this->image);
        this->receiver->applyEffectsOnBackground(this->image);
    }
};

We also need an Invoker class to maintain the command queue:

class Invoker {
private:
    vector<Command*> cmd_queue;
public:
    virtual ~Invoker() {
        for (auto& p : this->cmd_queue)
            delete p;
        cout << "Invoker destroyed" << endl;
    }

    void addCommand(Command* cmd) {
        this->cmd_queue.push_back(cmd);
    }

    void revokeLastCommand() {
        if (!this->cmd_queue.empty())
            this->cmd_queue.pop_back();
    }

    void executeAllCommands(string img) {
        for (auto& p : this->cmd_queue) {
            p->setCurrentImage(img);
            p->execute();
        }
    }
};

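The Invoker provides interfaces for adding a command, revoking the last command, and executing all commands. It owns the commands in its queue and deletes them in its destructor. revokeLastCommand() only removes a command from the queue, so the revoked command must still be deleted by whoever created it.

Finally, the main() function: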

int main() {
    string image = "IMAGE";
    Receiver* receiver = new Receiver();
    Invoker* invoker = new Invoker();
    Command* cmd1 = new EditPortrait(receiver);
    Command* cmd2 = new EditFullImage(receiver);
    Command* cmd3 = new EditPortrait(receiver);
    invoker->addCommand(cmd1);
    invoker->addCommand(cmd2);
    invoker->addCommand(cmd3);

    invoker->revokeLastCommand();

    invoker->executeAllCommands(image);

    delete receiver;
    delete invoker;  // also deletes cmd1 and cmd2, which are still in its queue
    delete cmd3;     // cmd3 was revoked, so the Invoker no longer owns it
    return 0;
}

The output is as follows:

Edit portrait...
  Add some special effects on the background of IMAGE
Edit full image...
  Add some special effects on the foreground of IMAGE
  Add some special effects on the background of IMAGE
Invoker destroyed
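With raw pointers, ownership of commands is easy to get wrong. The Invoker deletes whatever is still in its queue, while a revoked command must be freed by someone else.

The minimal sketch below makes the Invoker the sole owner by holding its queue in std::unique_ptr, so revoking a command also destroys it. For brevity, it folds the Receiver's work into the commands and returns log strings instead of printing. Those strings are illustrative only.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

class Command {
public:
    virtual ~Command() = default;
    virtual std::string execute(const std::string& img) const = 0;
};

class EditPortrait : public Command {
public:
    std::string execute(const std::string& img) const override {
        return "background(" + img + ")";
    }
};

class EditFullImage : public Command {
public:
    std::string execute(const std::string& img) const override {
        return "foreground+background(" + img + ")";
    }
};

// The queue owns its commands: revoking one destroys it, and nothing else ever deletes them.
class Invoker {
public:
    void addCommand(std::unique_ptr<Command> cmd) { queue_.push_back(std::move(cmd)); }
    void revokeLastCommand() {
        if (!queue_.empty())
            queue_.pop_back();  // unique_ptr frees the revoked command here
    }
    std::vector<std::string> executeAllCommands(const std::string& img) const {
        std::vector<std::string> log;  // one entry per executed command, in queue order
        for (const auto& cmd : queue_)
            log.push_back(cmd->execute(img));
        return log;
    }
    size_t pending() const { return queue_.size(); }
private:
    std::vector<std::unique_ptr<Command>> queue_;
};
```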

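Iterator

The Iterator pattern provides an interface for traversing the elements inside a container. The container's developer designs the traversal strategy, and callers of the container do not need to know its details.

For example, we often need to write our own dataset-reading interface that serves training data in mini-batches during iterative training. To save time, our data container simply stores all the data in an STL vector underneath. That way we only need to care about the interfaces of the container and the iterator, instead of spending effort on managing memory addresses.

The iterator code is as follows:

#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

template <typename T, typename U>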
class Iterator {
public:
    typedef typename vector<T>::iterator iter_type;
    Iterator(U *data) : container(data) {
        iter = container->data.begin();
    }

    void first() {
        iter = container->data.begin();
    }

    void next() {
        iter++;
    }

    vector<T> getBatch(int batch_size) {
        vector<T> rst;
        while (!this->reachEnd() && batch_size--)
            rst.push_back(this->getSample());
        return rst;
    }

    T getSample() {
        T rst = *iter;
        this->next();
        return rst;
    }

    bool reachEnd() {
        return (iter == container->data.end());
    }
private:
    U *container;
    iter_type iter;
};

The container is implemented as follows:

template <class T>
class Container {
    friend class Iterator<T, Container>;

public:
    Container(std::initializer_list<T> data) {
        for (auto& a : data)
            this->add(a);
    }

    void add(T a) {
        data.push_back(a);
    }

    Iterator<T, Container> *createIterator() {
        return new Iterator<T, Container>(this);
    }
private:
    std::vector<T> data;
};

Then we define an element type that holds a single sample:

class Sample {
public:
    Sample(string img) : data(img) {}

    string getData() {
        return data;
    }
private:
    string data;
};

Finally, the main() function. Here we test reading and iterating over two kinds of data, int and Sample:

int main() {
    cout << "Dataset of int" << endl;
    Container<int> dataset1{1,2,3,4,5,6};
    auto itr1 = dataset1.createIterator();

    cout << "Current sample: " << itr1->getSample() << endl;

    cout << "Next sample: " << itr1->getSample() << endl;

    auto batch1 = itr1->getBatch(3);
    cout << "Next sample batch:" << endl;
    for (auto& a : batch1)
        cout << a << ", ";
    cout << endl;

    batch1 = itr1->getBatch(3);
    cout << "Get another batch:" << endl;
    for (auto& a : batch1)
        cout << a << ", ";
    cout << endl;

    cout << "Dataset of Sample" << endl;
    Container<Sample> dataset2{};
    for (int i = 10; i < 15; i++)
        dataset2.add(Sample(to_string(i)));
    auto itr2 = dataset2.createIterator();

    cout << "Next sample: " << itr2->getSample().getData() << endl;

    auto batch2 = itr2->getBatch(3);
    cout << "Next sample batch:" << endl;
    for (auto& a : batch2)
        cout << a.getData() << ", ";
    cout << endl;

    delete itr1;  // createIterator() returns a heap-allocated iterator
    delete itr2;
    return 0;
}

The output is as follows:

Dataset of int
Current sample: 1
Next sample: 2
Next sample batch:
3, 4, 5, 
Get another batch:
6, 
Dataset of Sample
Next sample: 10
Next sample batch:
11, 12, 13,
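The traversal strategy belongs to the iterator rather than to its callers, so it can change without touching the calling code. The minimal sketch below assumes we want a DataLoader-style shuffled order.

The hypothetical ShuffledIterator keeps the same reachEnd()/getSample()/getBatch() interface. It visits a std::vector in a random permutation generated with std::shuffle and a seeded std::mt19937. It holds a reference to the data, so the vector must outlive the iterator.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical iterator that visits the data in a shuffled order.
template <typename T>
class ShuffledIterator {
public:
    ShuffledIterator(const std::vector<T>& data, unsigned seed)
        : data_(data), order_(data.size()) {
        std::iota(order_.begin(), order_.end(), size_t{0});  // 0, 1, ..., n-1
        std::mt19937 rng(seed);
        std::shuffle(order_.begin(), order_.end(), rng);     // permute the visiting order
    }
    bool reachEnd() const { return pos_ == order_.size(); }
    T getSample() {
        assert(!reachEnd());  // reading past the end is a caller error
        return data_[order_[pos_++]];
    }
    std::vector<T> getBatch(int batch_size) {
        std::vector<T> rst;  // may be shorter than batch_size at the end of an epoch
        while (!reachEnd() && batch_size-- > 0)
            rst.push_back(getSample());
        return rst;
    }
private:
    const std::vector<T>& data_;
    std::vector<size_t> order_;
    size_t pos_ = 0;
};
```

Every element is still visited exactly once per pass. Only the order differs, and that order is chosen entirely inside the iterator.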

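Mediator

The Mediator pattern fits situations where several classes are tightly coupled, so that every new feature requires changes in many places. With this pattern, components no longer need to know the details of other components, and the complex dependencies are handled by the mediator instead.

For example, suppose we want to implement a fairly complex image generation model that produces an image from a text description. If it generates a beach, it should also add a person to the beach; if it generates a street, it should add vehicles and people.

We split the generation task into two parts, foreground and background. The background part generates beaches and streets, and the foreground part generates people and vehicles. The foreground and background classes also depend on each other in specific ways. In this case, we can set up a mediator to coordinate the components.

First, declare the basic interface of the mediator:

#include <iostream>
#include <string>
using namespace std;

class Generator;
class Mediator {
public:
    virtual ~Mediator() {}
    virtual void execute(string img, Generator* gen = nullptr) const = 0;
};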

Next, define the base class of the generative models:

class Generator {
protected:
    Mediator* mediator;
public:
    Generator(Mediator* mediator = nullptr) : mediator(mediator) {}

    void setMediator(Mediator* mediator) {
        this->mediator = mediator;
    }
};

Then we can derive the foreground and background parts of the generative model:

class ForegroundGenerator : public Generator {
public:
    void generatePerson() {
        cout << "Generated a person in image" << endl;
        this->mediator->execute("", this);
    }
    void generateVehicle() {
        cout << "Generated a vehicle in image" << endl;
        this->mediator->execute("", this);
    }
};

class BackgroundGenerator : public Generator {
public:
    void generateBeach() {
        cout << "Generated beach background in image" << endl;
        this->mediator->execute("person", this);
    }

    void generateStreet() {
        cout << "Generated street background in image" << endl;
        this->mediator->execute("person + vehicle", this);
    }
};

Next comes the concrete mediator:

class GenerationMediator : public Mediator {
private:
    ForegroundGenerator* fg_gen;
    BackgroundGenerator* bg_gen;
public:
    GenerationMediator(ForegroundGenerator* fg, BackgroundGenerator* bg) : fg_gen(fg), bg_gen(bg) {
        this->fg_gen->setMediator(this);
        this->bg_gen->setMediator(this);
    }

    void execute(string request, Generator* gen = nullptr) const override {
        if (request.find("beach") != string::npos) {
            this->bg_gen->generateBeach();
        }
        if (request.find("street") != string::npos) {
            this->bg_gen->generateStreet();
        }
        if (request.find("person") != string::npos) {
            this->fg_gen->generatePerson();
        }
        if (request.find("vehicle") != string::npos) {
            this->fg_gen->generateVehicle();
        }
    }
};

Finally, the main() function:

int main() {
    string request1 = "I need an image of beach";
    ForegroundGenerator* fg = new ForegroundGenerator();
    BackgroundGenerator* bg = new BackgroundGenerator();
    GenerationMediator* mediator = new GenerationMediator(fg, bg);
    cout << "Generating image based on request: " << request1 << endl;
    mediator->execute(request1);

    string request2 = "I need an image of vehicle";
    cout << "Generating image based on request: " << request2 << endl;
    mediator->execute(request2);

    cout << "You can call Generator to trigger the Mediator:" << endl;
    bg->generateStreet();

    delete fg;
    delete bg;
    delete mediator;

    return 0;
}

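We can either call concrete components through the mediator, or call a component directly and let it trigger the mediator to carry out the follow-up steps.

The output is as follows: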

Generating image based on request: I need an image of beach
Generated beach background in image
Generated a person in image
Generating image based on request: I need an image of vehicle
Generated a vehicle in image
You can call Generator to trigger the Mediator:
Generated street background in image
Generated a person in image
Generated a vehicle in image
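A common alternative uses a single notify(sender, event) entry point. Components only report what happened, and the mediator alone decides which other components react. This sketch reflects our own assumptions rather than the original code:

- the generic Component class, the SceneMediator rules, and the event strings are hypothetical;
- the events vector exists only to record the order of calls.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

class Component;

// Components only know this interface, never each other.
class Mediator {
public:
    virtual ~Mediator() = default;
    virtual void notify(Component* sender, const std::string& event) = 0;
};

class Component {
public:
    explicit Component(std::string name) : name_(std::move(name)) {}
    void setMediator(Mediator* m) { mediator_ = m; }
    const std::string& name() const { return name_; }
    // Do some (omitted) work, then report it to the mediator.
    void run(const std::string& what) {
        if (mediator_)
            mediator_->notify(this, what);
    }
private:
    std::string name_;
    Mediator* mediator_ = nullptr;
};

// All cross-component rules live here; a new dependency only changes this class.
class SceneMediator : public Mediator {
public:
    SceneMediator(Component* fg, Component* bg) : fg_(fg), bg_(bg) {
        fg_->setMediator(this);
        bg_->setMediator(this);
    }
    void notify(Component* sender, const std::string& event) override {
        events.push_back(sender->name() + ":" + event);
        if (sender == bg_ && event == "beach") {
            fg_->run("person");
        } else if (sender == bg_ && event == "street") {
            fg_->run("person");
            fg_->run("vehicle");
        }
    }
    std::vector<std::string> events;  // every event, in the order it was reported
private:
    Component* fg_;
    Component* bg_;
};
```

Calling bg.run("street") triggers two follow-up foreground events, yet neither Component knows the other exists.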

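Memento

The Memento pattern fits situations where we need to record snapshots of an object so that its earlier states can be restored. It can be used to implement "undo".

For example, while developing the photo-editing app from the Command section, we decide that revoking commands that have not run yet is not enough. It would be better to also undo commands that have already been executed.

If every command were reversible, we could record the history of the operations themselves and undo an operation by applying its inverse. If some commands are irreversible, a simpler approach is to store a full copy of the data before each operation. We can implement this with the Memento pattern.

First, we define a Memento class that records one historical snapshot of the image:

#include <iostream>
#include <string>
#include <vector>
using namespace std;

class Memento {
public:
    virtual ~Memento() {}  // snapshots are deleted through Memento*
    virtual string getState() const = 0;
};

class ConcreteMemento : public Memento {
private:
    string state;
public:
    ConcreteMemento(string state) : state(state) {}

    string getState() const override {
        return this->state;
    }
};

Then we define an Editor class that implements the various operations on the image:

class Editor {
private:
    string image;
public:
    Editor(string img) : image(img) {
        cout << "[Editor] Initial image: " << this->image << endl;
    }

    void resizeToHalf() {
        int length = this->image.size();
        string rst = "";
        for (int i = 0; i < length; i++) {
            if (i % 2 == 0)
                rst.push_back(this->image[i]);
        }
        this->image = rst;
        cout << "[Editor] Image: " << this->image << endl;
    }

    void resizeTo2X() {
        int length = this->image.size();
        string rst = "";
        for (int i = 0; i < length; i++) {
            rst.push_back(this->image[i]);
            rst.push_back(this->image[i]);
        }
        this->image = rst;
        cout << "[Editor] Image: " << this->image << endl;
    }

    void brighten(int var) {
        int length = this->image.size();
        string rst = "";
        for (int i = 0; i < length; i++) {
            if (this->image[i] + var <= '9')
                rst.push_back(this->image[i] + var);
            else
                rst.push_back('9');
        }
        this->image = rst;
        cout << "[Editor] Image: " << this->image << endl;
    }

    Memento* save() {
        return new ConcreteMemento(this->image);
    }

    void restore(Memento* memento) {
        this->image = memento->getState();
        cout << "[Editor] Image is restored to: " << this->image << endl;
    }
};

Note that Editor keeps only the current state of the image, and declares it private. This protects the image data from being modified arbitrarily from outside.

In addition, shrinking and brightening the image are irreversible. Here we assume pixel brightness ranges from 0 to 9, with brighter values clipped to 9. The snapshots therefore store the full image data.

Next, we can define an App class that stores all the snapshots and provides an undo interface:

class App {
private:
    vector<Memento*> hist;
    Editor* editor;
public:
    App(Editor* editor) : editor(editor) {}
    ~App() {
        for (Memento* m : this->hist)  // free snapshots that were never restored
            delete m;
    }

    void backup() {
        cout << "[App] Saving current image..." << endl;
        this->hist.push_back(this->editor->save());
    }

    void undo() {
        if (this->hist.empty()) return;
        cout << "[App] Undoing last process..." << endl;
        Memento* mem = this->hist.back();
        this->hist.pop_back();
        this->editor->restore(mem);
        delete mem;  // the snapshot is no longer needed once it has been restored
    }
};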

Finally, the main() function:

int main() {
    string image = "12323454";
    Editor* editor = new Editor(image);
    App* app = new App(editor);

    app->backup();
    editor->resizeToHalf();

    app->backup();
    editor->resizeTo2X();

    app->backup();
    editor->brighten(6);

    app->undo();
    app->undo();
    app->undo();

    delete app;
    delete editor;
    return 0;
}

The output is as follows:

[Editor] Initial image: 12323454
[App] Saving current image...
[Editor] Image: 1335
[App] Saving current image...
[Editor] Image: 11333355
[App] Saving current image...
[Editor] Image: 77999999
[App] Undoing last process...
[Editor] Image is restored to: 11333355
[App] Undoing last process...
[Editor] Image is restored to: 1335
[App] Undoing last process...
[Editor] Image is restored to: 12323454
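Snapshots are needed here because brightening with clipping has no exact inverse. The small check below reuses the clipping rule of Editor::brighten() on the same intermediate image as main(). Here darken() is a hypothetical inverse operation of the kind an operation-history undo would need. Pixels that were clipped to 9 cannot be recovered from the operation alone.

```cpp
#include <cassert>
#include <string>

// Same rule as Editor::brighten(): each digit is raised by var and clipped at '9'.
std::string brighten(const std::string& img, int var) {
    std::string rst;
    for (char c : img)
        rst.push_back(c + var <= '9' ? static_cast<char>(c + var) : '9');
    return rst;
}

// Hypothetical naive inverse: lowers each digit by var, clipping at '0'.
// It cannot know which digits were clipped by brighten(), so those come back wrong.
std::string darken(const std::string& img, int var) {
    std::string rst;
    for (char c : img)
        rst.push_back(c - var >= '0' ? static_cast<char>(c - var) : '0');
    return rst;
}
```

Undoing brighten(6) by subtracting 6 turns "77999999" into "11333333", not the original "11333355". Restoring a memento, by contrast, returns the exact earlier data.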

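Observer

The Observer pattern fits situations where many objects need to receive messages from one class at the same time by "subscribing" to it, and can subscribe or unsubscribe at any time.

For example, recall the Proxy pattern example in the second article, "Understanding 22 Design Patterns Through Deep Learning (II): Structural Patterns". There, we designed a mechanism that lets each compute node in a distributed network return its results with a delay. Now we need to implement the step in which the master sends the merged gradients back to the compute nodes.

We assume that before each broadcast, the master does not know which nodes are busy with their own work and which are idle. We need an interface that delivers the gradient stored on the master to all idle worker nodes at once. Each worker must also be able to tell the master whether it is ready to receive new gradient data.

This closely resembles a subscription system. The publisher sends notifications on its own schedule to all current subscribers at once, and each subscriber can tell the publisher at any time whether it wants the next notification.

First, define the base classes for the worker (device) nodes and the master node:

#include <iostream>
#include <list>
#include <string>
using namespace std;

class DeviceNode {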
public:
    virtual ~DeviceNode() {};
    virtual void pullGrad(int grad) = 0;
};


class MasterNode {
public:
    virtual ~MasterNode() {};
    virtual void attach(DeviceNode* device) = 0;
    virtual void detach(DeviceNode* device) = 0;
    virtual void publishGrad() = 0;
};

Then we derive the concrete master node:

class RealMasterNode : public MasterNode {
private:
    int grad = 1;
    list<DeviceNode*> devices;
public:
    void attach(DeviceNode* device) override {
        cout << "[Master] Add a new device node" << endl;
        this->devices.push_back(device);
    }

    void detach(DeviceNode* device) override {
        cout << "[Master] Remove a device node" << endl;
        this->devices.remove(device);
    }

    void publishGrad() override {
        cout << "[Master] Publishing new grads to device nodes" << endl;
        for (auto& p : this->devices)
            p->pullGrad(this->grad);
    }
};

The master uses a container to record all the worker nodes that are currently "subscribed".

Next, the concrete worker node:

class RealDeviceNode : public DeviceNode {
private:
    int grad = 0;
    RealMasterNode& master;
public:
    RealDeviceNode(RealMasterNode& master) : master(master) {
        this->imReadyForNewGrad();
    }
    virtual ~RealDeviceNode() {}

    void pullGrad(int grad) override {
        this->grad += grad;
    }

    void imReadyForNewGrad() {
        this->master.attach(this);
    }

    void imNotReadyForNewGrad() {
        this->master.detach(this);
    }

    string getGrad() {
        return to_string(this->grad);
    }
};

Finally, the main() function:

int main() {
    RealMasterNode* master = new RealMasterNode();
    RealDeviceNode* device1 = new RealDeviceNode(*master);
    RealDeviceNode* device2 = new RealDeviceNode(*master);
    RealDeviceNode* device3 = new RealDeviceNode(*master);

    master->publishGrad();

    cout << "Device1 is slow. Not receiving new grads for now" << endl;
    device1->imNotReadyForNewGrad();
    cout << "Device2 is slow. Not receiving new grads for now" << endl;
    device2->imNotReadyForNewGrad();

    master->publishGrad();

    RealDeviceNode* device4 = new RealDeviceNode(*master);
    cout << "Device4 is slow. Not receiving new grads for now" << endl;
    device4->imNotReadyForNewGrad();
    cout << "Device1 is ready to receive new grads" << endl;
    device1->imReadyForNewGrad();

    master->publishGrad();

    cout << "Device1: " << device1->getGrad() << endl;
    cout << "Device2: " << device2->getGrad() << endl;
    cout << "Device3: " << device3->getGrad() << endl;
    cout << "Device4: " << device4->getGrad() << endl;

    delete master;
    delete device1;
    delete device2;
    delete device3;
    delete device4;
    return 0;
}

The output is as follows:

[Master] Add a new device node
[Master] Add a new device node
[Master] Add a new device node
[Master] Publishing new grads to device nodes
Device1 is slow. Not receiving new grads for now
[Master] Remove a device node
Device2 is slow. Not receiving new grads for now
[Master] Remove a device node
[Master] Publishing new grads to device nodes
[Master] Add a new device node
Device4 is slow. Not receiving new grads for now
[Master] Remove a device node
Device1 is ready to receive new grads
[Master] Add a new device node
[Master] Publishing new grads to device nodes
Device1: 2
Device2: 1
Device3: 3
Device4: 0
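In C++, subscribers are often registered as callbacks instead of through a DeviceNode base class. The sketch below assumes that variant:

- the hypothetical GradPublisher stores std::function<void(int)> callbacks in a std::map keyed by a subscription id;
- subscribe() returns the id that is later used to unsubscribe.

It mirrors attach(), detach(), and publishGrad() above, without requiring subscribers to inherit from anything.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <utility>

// Hypothetical callback-based publisher.
class GradPublisher {
public:
    using Callback = std::function<void(int)>;
    int subscribe(Callback cb) {  // counterpart of attach()
        int id = next_id_++;
        subscribers_[id] = std::move(cb);
        return id;
    }
    void unsubscribe(int id) { subscribers_.erase(id); }  // counterpart of detach()
    void publish(int grad) const {                        // counterpart of publishGrad()
        for (const auto& entry : subscribers_)
            entry.second(grad);  // notify every current subscriber
    }
private:
    std::map<int, Callback> subscribers_;
    int next_id_ = 0;
};
```

As with the std::list in RealMasterNode, a subscriber must not unsubscribe from inside its own callback while publish() is iterating. Erasing the current entry would invalidate the loop's iterator.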

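State

The State pattern fits situations where an object's behavior should change when its internal state changes. It can generally be applied whenever a function full of if/switch branches is rewritten as objects.

For example, suppose we have built a CNN but have not yet decided how to use it. We might train the whole network, fine-tune it on new data, or search for better hyperparameters with NAS. These different behaviors of the same object can be implemented with many if/switch branches, or with a state machine.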
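For comparison, here is a minimal sketch of the branch-based alternative. It uses a hypothetical Mode enum and SwitchNetwork class that are not part of the original code.

Every behavior that depends on the mode needs its own switch. Transitions such as "after a hyperparameter search, train from scratch" end up buried inside the branches.

```cpp
#include <cassert>
#include <string>

// Hypothetical branch-based version: one enum plus a switch in every mode-dependent method.
enum class Mode { FullTrain, FineTune, HyperParamSearch };

class SwitchNetwork {
public:
    explicit SwitchNetwork(Mode m) : mode_(m) {}
    void switchMode(Mode m) { mode_ = m; }
    Mode mode() const { return mode_; }
    std::string train() {
        switch (mode_) {
        case Mode::FullTrain:
            return "Training all the layers in network";
        case Mode::FineTune:
            return "Training only the last layer in network";
        case Mode::HyperParamSearch:
            mode_ = Mode::FullTrain;  // the transition is hard-coded inside the branch
            return "Searching the optim hyper-params in network";
        }
        return "";  // unreachable; avoids a missing-return warning
    }
private:
    Mode mode_;
};
```

Adding a new mode means editing every such switch. The State pattern instead moves each case into its own class.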

First, we define a State base class that represents a state:

#include <iostream>
#include <string>
using namespace std;

class Network;

class State {
protected:
    Network* network = nullptr;
public:
    virtual ~State() {}

    void setNetwork(Network* net) {
        this->network = net;
    }

    virtual void train() = 0;
    virtual string name() = 0;
};

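Next, define the class that represents the CNN. Network owns its current state: switchState() deletes the previous state, and the destructor deletes the current one.

class Network {
private:
    State* state = nullptr;  // must start as nullptr, because switchState() checks it before deleting
public:
    Network(State* state) {
        this->switchState(state);
    }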
    ~Network() {
        delete state;
    }

    void switchState(State* state) {
        cout << "[Network] Switch to state: " << state->name() << endl;
        if (this->state != nullptr)
            delete this->state;
        this->state = state;
        this->state->setNetwork(this);
    }

    void train() {
        this->state->train();
    }
};

Then we can derive several concrete state classes from State:

class FullTrainState : public State {
public:
    void train() override {
        cout << "Training all the layers in network" << endl;
    }
    string name() override {
        return "Train-From-Scratch";
    }
};


class FineTuneState : public State {
public:
    void train() override {
        cout << "Training only the last layer in network" << endl;
    }
    string name() override {
        return "Fine-Tuning";
    }
};


class HyperParamSearchState : public State {
public:
    void train() override {
        cout << "Searching the optim hyper-params in network" << endl;
        // switchState() deletes this object, so nothing after this call may touch its members
        this->network->switchState(new FullTrainState());
    }
    string name() override {
        return "Hyper-Param-Search";
    }
};

Finally, the main() function:

int main() {
    FullTrainState* full_train = new FullTrainState();
    Network* network = new Network(full_train);
    network->train();

    HyperParamSearchState* hyper_search = new HyperParamSearchState();
    network->switchState(hyper_search);
    network->train();

    FineTuneState* fine_tune = new FineTuneState();
    network->train();
    network->switchState(fine_tune);
    network->train();

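    // Network owns its states: full_train and hyper_search were already deleted by
    // switchState(), and fine_tune is deleted by ~Network(), so only network is deleted here.
    delete network;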
    return 0;
}

With the State pattern, we can switch an object's internal state at run time and thereby change its behavior. The output is as follows:

[Network] Switch to state: Train-From-Scratch
Training all the layers in network
[Network] Switch to state: Hyper-Param-Search
Searching the optim hyper-params in network
[Network] Switch to state: Train-From-Scratch
Training all the layers in network
[Network] Switch to state: Fine-Tuning
Training only the last layer in network

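Strategy

The Strategy pattern fits situations where business logic must be kept separate from concrete algorithms, or where the algorithm must be switchable at run time.

It differs from the State pattern in who drives the change. The Strategy pattern turns the algorithm itself into an object that the client chooses from outside, and strategies usually do not know about each other. States, by contrast, may trigger transitions among themselves, as HyperParamSearchState did above. This separates an object's data from its behavior more thoroughly.

For example, deep learning frameworks keep the model parameters in one object and update them with another object, the optimizer. This design follows the Strategy pattern.

We first define the optimizer base class:

#include <initializer_list>
#include <iostream>
#include <vector>
using namespace std;

class Optimizer {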
public:
    virtual ~Optimizer() {}
    virtual void step(vector<float>& params) const = 0;
};

Next, define a model class that stores the parameter data:

class Model {
private:
    Optimizer* optim;
    vector<float> params;
public:
    Model(Optimizer* optimizer) : optim(optimizer) {}
    ~Model() {
        delete this->optim;
    }

    void initParameters(std::initializer_list<float> data) {
        for (float a : data)
            this->params.push_back(a);
    }

    void printParameters() {
        cout << "All parameters in the model:" << endl;
        for (float a : params)
            cout << a << ", ";
        cout << endl;
    }

    void setOptimizationMethod(Optimizer* optimizer) {
        delete this->optim;
        this->optim = optimizer;
    }

    void train() {
        this->optim->step(this->params);
    }
};

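Then we derive two optimization algorithms from the optimizer base class: gradient descent and gradient ascent. Here we simply use addition and subtraction to stand in for the different algorithms: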

class GradienDescend : public Optimizer {
private:
    float lr;
public:
    GradienDescend(float lr = 0.1f) : Optimizer(), lr(lr) {}
    void step(vector<float>& params) const override {
        cout << "Train the model with gradient descend method" << endl;
        for (float& a : params)
            a -= this->lr;
    }
};

class GradienAscend : public Optimizer {
private:
    float lr;
public:
    GradienAscend(float lr = 0.1f) : Optimizer(), lr(lr) {}
    void step(vector<float>& params) const override {
        cout << "Train the model with gradient ascend method" << endl;
        for (float& a : params)
            a += this->lr;
    }
};

Finally, the main() function:

int main()
{
    GradienDescend* optim1 = new GradienDescend(0.1);
    Model* model = new Model(optim1);
    model->initParameters({1, 2, 3, 4, 5});
    model->printParameters();

    model->train();
    model->printParameters();

    GradienAscend* optim2 = new GradienAscend(0.5);
    model->setOptimizationMethod(optim2);
    model->train();
    model->printParameters();

    delete model;  // also deletes optim2; optim1 was already deleted by setOptimizationMethod()
    return 0;
}

The output is as follows:

All parameters in the model:
1, 2, 3, 4, 5, 
Train the model with gradient descend method
All parameters in the model:
0.9, 1.9, 2.9, 3.9, 4.9, 
Train the model with gradient ascend method
All parameters in the model:
1.4, 2.4, 3.4, 4.4, 5.4,
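A strategy here is just "something that updates a parameter vector", so it can also be expressed as a callable instead of a class hierarchy. The minimal sketch below assumes std::function is an acceptable optimizer type:

- gradientDescent() and gradientAscent() are hypothetical factory functions that return lambdas capturing the learning rate;
- Model simply stores and calls whichever one it was given.

No manual delete is needed, because std::function owns its callable.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <utility>
#include <vector>

// A strategy as a plain callable that updates the parameters in place.
using Optimizer = std::function<void(std::vector<float>&)>;

// Hypothetical factories: each returns a lambda that captures its learning rate.
Optimizer gradientDescent(float lr) {
    return [lr](std::vector<float>& params) {
        for (float& a : params) a -= lr;
    };
}

Optimizer gradientAscent(float lr) {
    return [lr](std::vector<float>& params) {
        for (float& a : params) a += lr;
    };
}

class Model {
public:
    explicit Model(Optimizer optim) : optim_(std::move(optim)) {}
    void setOptimizationMethod(Optimizer optim) { optim_ = std::move(optim); }
    void train() { optim_(params); }  // delegate the update to the current strategy
    std::vector<float> params;
private:
    Optimizer optim_;
};
```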

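Template Method

The Template Method pattern fits situations where algorithms for several different kinds of objects should be combined into a single class, and most of their steps are the same.

For example, suppose we want to design a general model-initialization interface. Some of its operations can have default behavior. Building a DataLoader from the dataset files (to serve sample batches during iterative training) and initializing the base CNN should not change much.

Other operations depend on the specific data and problem: reading the raw dataset, defining the final classification layer of the CNN, and defining the loss function. With the Template Method, we put the shared behavior in the base class. For each specific problem, a subclass then implements the steps that vary.

First, define the base class that contains the default behavior:

#include <iostream>
using namespace std;

class AbstractModel {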
protected:
    virtual void initDataset() const = 0;
    void initDataLoader() const {
        cout << "[AbstractModel] Created a DataLoader with batch size of 64 from Dataset" << endl;
    }
    void initBaseNetwork() const {
        cout << "[AbstractModel] Initialized the base network with pretrained weights" << endl;
    }
    virtual void initFinalLayer() const = 0;
    virtual void initLossFunction() const = 0;
public:
    void init() const {
        this->initDataset();
        this->initDataLoader();
        this->initBaseNetwork();
        this->initFinalLayer();
        this->initLossFunction();
    }
};

Then we can implement the remaining interfaces for each specific problem:

class ClassificationModel : public AbstractModel {
protected:
    void initDataset() const override {
        cout << "[Classification] Loaded ImageNet dataset" << endl;
    }
    void initFinalLayer() const override {
        cout << "[Classification] Let's use a linear layer" << endl;
    }
    void initLossFunction() const override {
        cout << "[Classification] Let's use cross-entropy" << endl;
    }
};

class DetectionModel : public AbstractModel {
protected:
    void initDataset() const override {
        cout << "[Detection] Loaded COCO dataset" << endl;
    }
    void initFinalLayer() const override {
        cout << "[Detection] Let's add several layers and train them" << endl;
    }
    void initLossFunction() const override {
        cout << "[Detection] Let's use RetinaLoss" << endl;
    }
};

The rest of the code is as follows:

void init_api(AbstractModel* model) {
    model->init();
}


int main() {
    cout << "I want to initialize a classification model" << endl;
    ClassificationModel* model1 = new ClassificationModel();
    init_api(model1);

    cout << "\nI want to initialize a detection model" << endl;
    DetectionModel* model2 = new DetectionModel();
    init_api(model2);

    delete model1;
    delete model2;
    return 0;
}

The output is as follows:

I want to initialize a classification model
[Classification] Loaded ImageNet dataset
[AbstractModel] Created a DataLoader with batch size of 64 from Dataset
[AbstractModel] Initialized the base network with pretrained weights
[Classification] Let's use a linear layer
[Classification] Let's use cross-entropy

I want to initialize a detection model
[Detection] Loaded COCO dataset
[AbstractModel] Created a DataLoader with batch size of 64 from Dataset
[AbstractModel] Initialized the base network with pretrained weights
[Detection] Let's add several layers and train them
[Detection] Let's use RetinaLoss

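The difference between the Template Method and Strategy patterns is in when the algorithm is fixed:

- The Template Method works at the class level. Each subclass fixes its steps through inheritance when the code is written and compiled, and an object cannot change them afterwards.
- The Strategy pattern works at the object level. The algorithm is chosen at run time and can be swapped while the program is running.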
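The compile-time binding can be made fully static with the curiously recurring template pattern (CRTP), which avoids virtual functions entirely. This sketch is an alternative to the article's virtual-function version:

- the hypothetical ModelInit<Derived>::init() fixes the order of the steps;
- Classification supplies the problem-specific steps;
- the step names are placeholders, returned as strings so the order can be checked.

The call self.initDataset() is resolved when ModelInit<Classification> is instantiated, with no virtual dispatch involved.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical compile-time template method: the base class calls the
// derived steps through its template parameter.
template <typename Derived>
class ModelInit {
public:
    std::vector<std::string> init() const {
        const auto& self = static_cast<const Derived&>(*this);
        std::vector<std::string> steps;
        steps.push_back(self.initDataset());       // supplied by Derived
        steps.push_back("DataLoader");             // shared default step
        steps.push_back("BaseNetwork");            // shared default step
        steps.push_back(self.initFinalLayer());    // supplied by Derived
        steps.push_back(self.initLossFunction());  // supplied by Derived
        return steps;
    }
};

class Classification : public ModelInit<Classification> {
public:
    std::string initDataset() const { return "ImageNet"; }
    std::string initFinalLayer() const { return "Linear"; }
    std::string initLossFunction() const { return "CrossEntropy"; }
};
```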

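Visitor

The Visitor pattern fits situations where some operation must be applied to every element of an object structure. It is especially useful when the operation's implementation should be kept separate from the objects themselves, so that its behavior can change without touching the object structure.

Other patterns can also switch behavior at run time, such as Command or Strategy. What distinguishes the Visitor is that the operation actually executed depends on the runtime types of both the visitor and the visited element (double dispatch). accept() dispatches on the element type, and the visitXxx() call it makes dispatches on the visitor type.

For example, suppose we have implemented a CNN built from two basic blocks, convolutional layers and fully connected layers. While tuning the model, besides accuracy we also want other metrics, such as the amount of computation (FLOPs). We may also need to visualize each layer's parameters during training.

Convolutional and fully connected layers are structured differently, so their FLOPs calculation and visualization differ too. We can implement these features with the Visitor pattern.

We first declare the basic visitor interface:

#include <iostream>
#include <string>
#include <vector>
using namespace std;

class ConvLayer;
class LinearLayer;

class Visitor {
public:
    virtual ~Visitor() {}
    virtual void visitConvLayer(const ConvLayer* layer) const = 0;
    virtual void visitLinearLayer(const LinearLayer* layer) const = 0;
};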

Then we define the convolutional and fully connected layers:

class Layer {
public:
    virtual ~Layer() {}
    virtual void accept(Visitor* visitor) const = 0;
};


class ConvLayer : public Layer {
public:
    void accept(Visitor* visitor) const override {
        visitor->visitConvLayer(this);
    }

    string getParam() const {
        return "[x,y,z,w]";
    }
};

class LinearLayer : public Layer {
public:
    void accept(Visitor* visitor) const override {
        visitor->visitLinearLayer(this);
    }

    string getParam() const {
        return "[x,y]";
    }
};

Next, we implement FLOPs calculation and parameter visualization for both layer types:

class calcFlopsVisitor : public Visitor {
public:
    void visitConvLayer(const ConvLayer* layer) const override {
        cout << "Calculating the FLOPS of convolution layer: " << layer->getParam() << endl;
    }
    void visitLinearLayer(const LinearLayer* layer) const override {
        cout << "Calculating the FLOPS of linear layer: " << layer->getParam() << endl;
    }
};

class visParamVisitor : public Visitor {
public:
    void visitConvLayer(const ConvLayer* layer) const override {
        cout << "Visualizing the kernels of convolution layer: " << layer->getParam() << endl;
    }
    void visitLinearLayer(const LinearLayer* layer) const override {
        cout << "Visualizing the weight of linear layer: " << layer->getParam() << endl;
    }
};

We then wrap the layers into a network:

class Network {
private:
    vector<Layer*> layers;
public:
    virtual ~Network() {
        for (Layer* p : this->layers)
            delete p;
    }
    void addLayer(Layer* layer) {
        this->layers.push_back(layer);
    }

    void visitorHelper(Visitor* visitor) {
        for (Layer* p : this->layers)
            p->accept(visitor);
    }
};

Finally, the main() function:

int main() {
    Network* network = new Network();
    network->addLayer(new ConvLayer());
    network->addLayer(new ConvLayer());
    network->addLayer(new LinearLayer());
    network->addLayer(new LinearLayer());

    calcFlopsVisitor* visitor1 = new calcFlopsVisitor();
    network->visitorHelper(visitor1);

    visParamVisitor* visitor2 = new visParamVisitor();
    network->visitorHelper(visitor2);

    delete visitor1;
    delete visitor2;
    delete network;
    return 0;
}

The output is as follows:

Calculating the FLOPS of convolution layer: [x,y,z,w]
Calculating the FLOPS of convolution layer: [x,y,z,w]
Calculating the FLOPS of linear layer: [x,y]
Calculating the FLOPS of linear layer: [x,y]
Visualizing the kernels of convolution layer: [x,y,z,w]
Visualizing the kernels of convolution layer: [x,y,z,w]
Visualizing the weight of linear layer: [x,y]
Visualizing the weight of linear layer: [x,y]
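Since C++17, the standard library offers a closed-set alternative to the accept()/visitXxx() pair. std::variant stores one of a fixed list of types, and std::visit calls the visitor overload that matches the type currently held.

The sketch below assumes layers can be modeled as a variant of two simple structs. The struct definitions, the visitAll() helper, and the CalcFlopsVisitor output strings are illustrative only.

As in the classic Visitor, adding a new operation is cheap: write a new struct with one operator() per layer type. Adding a new layer type, however, requires updating every visitor.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

struct ConvLayer   { std::string getParam() const { return "[x,y,z,w]"; } };
struct LinearLayer { std::string getParam() const { return "[x,y]"; } };
using Layer = std::variant<ConvLayer, LinearLayer>;  // closed set of layer types

// One overload per layer type; std::visit selects it from the type held by the variant.
struct CalcFlopsVisitor {
    std::string operator()(const ConvLayer& l) const { return "conv FLOPs " + l.getParam(); }
    std::string operator()(const LinearLayer& l) const { return "linear FLOPs " + l.getParam(); }
};

// Hypothetical helper: apply any visitor to every layer and collect the results.
template <typename V>
std::vector<std::string> visitAll(const std::vector<Layer>& layers, V visitor) {
    std::vector<std::string> out;
    for (const Layer& layer : layers)
        out.push_back(std::visit(visitor, layer));
    return out;
}
```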

That covers all 22 object-oriented design patterns in this series.

