设计模式

1 minute read

Published:

工厂模式

设计模式的工厂模式是一种创建型设计模式,用于处理对象的创建过程,将对象的创建和使用分离。这种模式对于管理对象的创建逻辑非常有用,尤其是在需要灵活地生成不同类型的对象时。

工厂模式的关键概念:

  1. 封装性:工厂模式隐藏了对象创建的具体细节,调用者不需要知道对象是如何被创建和表示的。

  2. 扩展性:当需要添加新的对象类型时,只需添加相应的具体产品类和工厂类,而无需修改现有代码,符合开闭原则(对扩展开放,对修改封闭)。

  3. 解耦:对象的创建逻辑和使用逻辑分离,降低了系统各部分之间的耦合度。

工厂模式的分类:

  1. 简单工厂模式:通过一个单一的工厂类来创建不同的对象。简单工厂模式不是真正的设计模式,因为它没有将创建逻辑封装起来。

  2. 工厂方法模式:定义了一个用于创建对象的接口,让子类决定实例化哪一个类。这把对象的创建推迟到子类。

  3. 抽象工厂模式:创建相关或依赖对象的家族,而不需明确指定具体类。当需要支持多个产品族时使用。

工厂模式的结构:

  • 产品(Product):定义了产品的接口,产品是工厂模式创建的对象。
  • 具体产品(Concrete Product):实现了产品接口的具体类。
  • 工厂(Factory):定义了创建产品的接口,声明了创建产品的方法。
  • 具体工厂(Concrete Factory):实现了工厂接口,生成具体产品对象。

工厂模式的使用场景:

  • 当创建逻辑复杂时,使用工厂模式可以将创建逻辑封装在工厂内部。
  • 当一个类不知道它所必须创建的对象的类时。
  • 当需要通过不同的工厂生成不同的产品族时。

优点:

  • 低耦合性:客户端不需要知道具体的类是如何实现的,只需要知道工厂的接口。
  • 代码可维护性:新增产品时,只需新增相应的类和工厂类,无需修改现有代码。
  • 可扩展性:容易扩展,增加新的产品和工厂类不会影响现有系统。

缺点:

  • 增加系统的复杂度:每增加一个产品类别,都需要增加一个具体类和产品类,可能会导致系统中类的数量急剧增加。
  • 增加系统的抽象性:在添加新的产品类时,需要对工厂的接口进行扩展,这可能需要修改抽象工厂类及其所有子类。

我们下面看一下具体的简单工厂模式的例子:

import collections
from itertools import repeat
from typing import List, Dict, Any

def _ntuple(n, name="parse"):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")
_quadruple = _ntuple(4, "_quadruple")

if __name__ == '__main__':
    print(f'{_pair(3)}')
    print(f'{_pair((3,3))}')

Output:

(3, 3)
(3, 3)

我们通过_ntuple()来创建输出不同长度tuple的函数,比如我们创建_pair,即不管输入是1个字符,还是2个字符,返回的都是2个字符。

这个的一个非常常用的场景就是Pytorch中,我们给kernel_size、padding等传入一个数字时,pytorch也可以解析为一个tuple。

CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

Parameters
in_channels (int)  Number of channels in the input image
out_channels (int)  Number of channels produced by the convolution
kernel_size (int or tuple)  Size of the convolving kernel
stride (int or tuple, optional)  Stride of the convolution. Default: 1
padding (int, tuple or str, optional)  Padding added to all four sides of the input. Default: 0
padding_mode (str, optional)  'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
dilation (int or tuple, optional)  Spacing between kernel elements. Default: 1
groups (int, optional)  Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional)  If True, adds a learnable bias to the output. Default: True

可以看到 kernel_size可以为int,也可以为tuple。通过上面例子的工厂函数,就可以实现这样的需求。

我们上面列出的这个_ntuple()函数,正是pytorch源码中的函数。