15.python设计模式【函数工厂模式】
1.知识讲解
- 内容:定义一个字典,在python中一切皆对象,将所有的函数进行封装,然后定一个分发函数进行分发,将原来if…else全部干掉。
- 角色:
- 函数(function)
- 函数工厂(function factory)
- 客户端 (client)
- 举个例子:
需求:封装一个函数,能够同时进行加减乘除运算。
加减乘除函数:
# 定义一个计算器的相关功能
def plus(a, b):return a + bdef substact(a, b):return a - bdef multiply(a, b):return a * bdef divide(a, b):return a / b
定义封装函数:
# 定义一个计算函数
def cal(a, b, how):if how == 1:return plus(a, b)elif how == 2:return substact(a, b)elif how == 3:return multiply(a, b)else:return None
从上面这个封装函数来看,太多了if…else…很冗余
于是定义一个函数工厂,将所有函数进行封装,然后根据函数名进行调用
# 定义函数工厂
# 在python里面一切皆是对象
# 定义了一个字典,key是函数名称,value是函数对象
func_map = {"plus": plus,"substract": substact,"multiply": multiply,"divide": divide
}
# 函数工厂模式就是一种对函数进行动态分发的模式
def cal(a,b,how):if how in func_map.keys():return func_map[how](a,b)else:return None
- 优点:
- 对函数进行动态分发,减少了函数的冗余代码。
2.实战
2.1 demo1
需求:这个是我在写深度学习项目的时候遇到的一个设计模式,当初不明白,现在明白了这个设计模式。自然语言处理中,有一次有一个实验,需要同时验证Bert,roberta,gpt,Xnet等预训练模型的相关功能的性能,他们大致分以下几个模块
- config
- tokenizer
- 掩码模型:Bert,roberta,gpt使用的是mlm掩码模型,而Xnet使用的是plm掩码模型
- 自带的分类模型:sequence_classifier ,但是GPT没有
因为他们每个的这四个部分的功能实现都不相同,但是在实验过程中都需要用到,因此就用到了函数工厂模式。
from torch import nn
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, BertForMaskedLM, RobertaConfig, \RobertaTokenizer, RobertaForSequenceClassification, RobertaForMaskedLM, XLMRobertaConfig, XLMRobertaTokenizer, \XLMRobertaForSequenceClassification, XLMRobertaForMaskedLM, XLNetConfig, XLNetTokenizer, \XLNetForSequenceClassification, XLNetLMHeadModel, AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification, \AlbertForMaskedLM, GPT2Config, GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer# 定义一个函数工厂,将所有的函数全部用一个字典封装好,到时候用到那个预训练模型,则就根据预训练模型的名称调用对应的函数。
MODEL_CLASSES = {'bert': {'config': BertConfig,'tokenizer': BertTokenizer,"sequence_classifier": BertForSequenceClassification,"mlm":BertForMaskedLM},'roberta': {'config': RobertaConfig,'tokenizer': RobertaTokenizer,"sequence_classifier": RobertaForSequenceClassification,"mlm": RobertaForMaskedLM},'xlm-roberta': {'config': XLMRobertaConfig,'tokenizer': XLMRobertaTokenizer,"sequence_classifier": XLMRobertaForSequenceClassification,"mlm": XLMRobertaForMaskedLM},'xlnet': {'config': XLNetConfig,'tokenizer': XLNetTokenizer,"sequence_classifier": XLNetForSequenceClassification,"plm": XLNetLMHeadModel},'albert': {'config': AlbertConfig,'tokenizer': AlbertTokenizer,"sequence_classifier": AlbertForSequenceClassification,"mlm": AlbertForMaskedLM},'gpt2': {'config': GPT2Config,'tokenizer': GPT2Tokenizer,"mlm": GPT2LMHeadModel},
}class TransformerModelWrapper(nn.Module):# 基于Transformer的语言模型的包装器。'''WrapperConfig封装了:model_type为Bert,roberta,gpt,Xnet,wrapper_type为mlm和plm两种类型'''def __init__(self, config: WrapperConfig):super(TransformerModelWrapper, self).__init__()self.config = configconfig_class = MODEL_CLASSES[self.config.model_type]['config']tokenizer_class = MODEL_CLASSES[self.config.model_type]['tokenizer']model_class = MODEL_CLASSES[self.config.model_type][self.config.wrapper_type]