4000-520-616
欢迎来到免疫在线!(蚂蚁淘生物旗下平台)  请登录 |  免费注册 |  询价篮
主营:原厂直采,平行进口,授权代理(蚂蚁淘为您服务)
咨询热线电话
4000-520-616
当前位置: 首页 > 新闻动态 >
新闻详情
基于BERT fine-tuning的中文标题分类实战 - 算法网
来自 : ddrv.cn/a/585... 发布时间:2021-03-25

BERT的问世向世人宣告了无监督预训练的语言模型在众多NLP任务中成为“巨人肩膀”的可能性,接踵而出的GPT2、XL-Net则不断将NLP从业者的期望带向了新的高度。得益于这些力作模型的开源,使得我们在了解其论文思想的基础上,可以借力其凭借强大算力预训练的模型从而快速在自己的数据集上开展实验,甚至应用于真实的业务中。

在GitHub上已经存在使用多种语言/框架依照Google最初release的TensorFlow版本的代码进行实现的Pretrained-BERT,并且都提供了较为详细的文档。本文主要展示通过极简的代码调用Pytorch Pretrained-BERT并进行fine-tuning的文本分类任务。

下面的代码是使用pytorch-pretrained-BERT进行文本分类的官方实现,感兴趣的同学可以直接点进去阅读:

https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py​ github.com

本文所使用的数据是标题及其对应的类别,如“中国的垃圾分类能走多远”对应“社会”类别,共有28个类别,每个类别的训练数据和测试数据各有1000条,数据已经同步至云盘,欢迎下载。链接:

https://pan.baidu.com/s/1r4SI6-IizlCcsyMGL7RU8Q​ pan.baidu.com

提取码: 6awx

import osimport sysimport pickleimport pandas as pdimport numpy as npfrom concurrent.futures import ThreadPoolExecutorimport torchimport picklefrom sklearn.preprocessing import LabelEncoderfrom torch.optim import optimizerfrom torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDatasetfrom torch.nn import CrossEntropyLoss,BCEWithLogitsLossfrom tqdm import tqdm_notebook, trangefrom pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassificationfrom pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedulefrom sklearn.metrics import precision_recall_curve,classification_reportimport matplotlib.pyplot as plt%matplotlib inline
# pandas读取数据data = pd.read_pickle(\"title_category.pkl\")# 列名重新命名data.columns = [\'text\',\'label\']

因为label为中文格式,为了适应模型的输入需要进行ID化,此处调用sklearn中的label encoder方法快速进行变换。

le = LabelEncoder()le.fit(data.label.tolist())data[\'label\'] = le.transform(data.label.tolist())

\"《基于BERT
\"《基于BERT

训练数据准备

本文需要使用的预训练bert模型为使用中文维基语料训练的字符级别的模型,在Google提供的模型列表中对应的名称为 bert-base-chinese ,使用更多语言语料训练的模型名称可以参见下方链接:https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py。

另外,首次执行下面的代码时因为本地没有cache,因此会自动启动下载,实践证明下载速度还是很快的。需要注意的是,do_lower_case参数需要手动显式的设置为False。

# 分词工具bert_tokenizer = BertTokenizer.from_pretrained(\'bert-base-chinese\', do_lower_case=False)# 封装类class DataPrecessForSingleSentence(object): \"\"\" 对文本进行处理 \"\"\" def __init__(self, bert_tokenizer, max_workers=10): \"\"\" bert_tokenizer :分词器 dataset :包含列名为\'text\'与\'label\'的pandas dataframe \"\"\" self.bert_tokenizer = bert_tokenizer # 创建多线程池 self.pool = ThreadPoolExecutor(max_workers=max_workers) # 获取文本与标签 def get_input(self, dataset, max_seq_len=30): \"\"\" 通过多线程(因为notebook中多进程使用存在一些问题)的方式对输入文本进行分词、ID化、截断、填充等流程得到最终的可用于模型输入的序列。 入参: dataset : pandas的dataframe格式,包含两列,第一列为文本,第二列为标签。标签取值为{0,1},其中0表示负样本,1代表正样本。 max_seq_len : 目标序列长度,该值需要预先对文本长度进行分别得到,可以设置为小于等于512(BERT的最长文本序列长度为512)的整数。 出参: seq : 在入参seq的头尾分别拼接了\'CLS\'与\'SEP\'符号,如果长度仍小于max_seq_len,则使用0在尾部进行了填充。 seq_mask : 只包含0、1且长度等于seq的序列,用于表征seq中的符号是否是有意义的,如果seq序列对应位上为填充符号, 那么取值为1,否则为0。 seq_segment : shape等于seq,因为是单句,所以取值都为0。 labels : 标签取值为{0,1},其中0表示负样本,1代表正样本。 \"\"\" sentences = dataset.iloc[:, 0].tolist() labels = dataset.iloc[:, 1].tolist() # 切词 tokens_seq = list( self.pool.map(self.bert_tokenizer.tokenize, sentences)) # 获取定长序列及其mask result = list( self.pool.map(self.trunate_and_pad, tokens_seq, [max_seq_len] * len(tokens_seq))) seqs = [i[0] for i in result] seq_masks = [i[1] for i in result] seq_segments = [i[2] for i in result] return seqs, seq_masks, seq_segments, labels def trunate_and_pad(self, seq, max_seq_len): \"\"\" 1. 因为本类处理的是单句序列,按照BERT中的序列处理方式,需要在输入序列头尾分别拼接特殊字符\'CLS\'与\'SEP\', 因此不包含两个特殊字符的序列长度应该小于等于max_seq_len-2,如果序列长度大于该值需要那么进行截断。 2. 对输入的序列 最终形成[\'CLS\',seq,\'SEP\']的序列,该序列的长度如果小于max_seq_len,那么使用0进行填充。 入参: seq : 输入序列,在本处其为单个句子。 max_seq_len : 拼接\'CLS\'与\'SEP\'这两个特殊字符后的序列长度 出参: seq : 在入参seq的头尾分别拼接了\'CLS\'与\'SEP\'符号,如果长度仍小于max_seq_len,则使用0在尾部进行了填充。 seq_mask : 只包含0、1且长度等于seq的序列,用于表征seq中的符号是否是有意义的,如果seq序列对应位上为填充符号, 那么取值为1,否则为0。 seq_segment : shape等于seq,因为是单句,所以取值都为0。 \"\"\" # 对超长序列进行截断 if len(seq) (max_seq_len - 2): seq = seq[0:(max_seq_len - 2)] # 分别在首尾拼接特殊符号 seq = [\'[CLS]\'] + seq + [\'[SEP]\'] # ID化 seq = self.bert_tokenizer.convert_tokens_to_ids(seq) # 根据max_seq_len与seq的长度产生填充序列 padding = [0] * (max_seq_len - len(seq)) # 创建seq_mask seq_mask = [1] * len(seq) + padding # 创建seq_segment seq_segment = [0] * len(seq) + padding # 对seq拼接填充序列 seq += padding assert len(seq) == max_seq_len assert len(seq_mask) == max_seq_len assert len(seq_segment) == max_seq_len return seq, seq_mask, seq_segment

DataPrecessForSingleSentence是一个用于将pandas Dataframe转化为模型输入的类,每个函数的入参和出参已经写得比较清晰翔实了。处理流程大致如下:

通过多线程的方式进行调用tokenize进行切词(字符级别)对于切词产生的序列如果长度大于设置的max_seq_len-2时需要进行截断。BERT中使用的max_seq_len是512,因此最长不可以超过512个字符。另外,本处需要减2的原因在于还需要在原始序列上拼接两个特殊符号,因此需要预留两个字符的“槽位”。在首、尾分别拼接\'[CLS] 及\'[SEP] ,如果序列长度不足max_seq_len,使用0进行填充。产生相应的mask序列和segment序列,其中mask序列使用0、1值标注对应位上是否为填充符号,如果是那么取值为0,负责为1,如果序列长度不足max_seq_len,使用0进行填充。segment序列则用于表示序列是否为同一个输入源,在本例中取值全部为0,如果序列长度不足max_seq_len,使用0进行填充。对于填充后的序列进行ID化,调用的是convert_tokens_to_ids方法,最终返回seq,seq_mask 与seq_segment序列。
# 类初始化processor = DataPrecessForSingleSentence(bert_tokenizer= bert_tokenizer)# 产生输入ju 数据seqs, seq_masks, seq_segments, labels = processor.get_input( dataset=data, max_seq_len=30)

本文设定的max_seq_len为30,因为通过统计标题的长度可以得知30已经是其85百分位数,基本已经涵盖了绝大部分样本。

加载预训练的bert模型
# 加载预训练的bert模型model = BertForSequenceClassification.from_pretrained( \'bert-base-chinese\', num_labels=28)

同样,首次执行会自动启动下载,在本例中因为有28个类别,因此num_labels参数需要设置为28。

数据格式化

数据格式化指的是将list格式的数据转化为torch的tensor格式。

# 转换为torch tensort_seqs = torch.tensor(seqs, dtype=torch.long)t_seq_masks = torch.tensor(seq_masks, dtype = torch.long)t_seq_segments = torch.tensor(seq_segments, dtype = torch.long)t_labels = torch.tensor(labels, dtype = torch.long)train_data = TensorDataset(t_seqs, t_seq_masks, t_seq_segments, t_labels)train_sampler = RandomSampler(train_data)train_dataloder = DataLoader(dataset= train_data, sampler= train_sampler,batch_size = 256)

使用了TensorDataset、RandomSampler、DataLoader对输入数据进行了封装,相较于自己编写generator代码量简短很多,此处设置的batch size为256。

# 将模型转换为trin modemodel.train()BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (4): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (8): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): BertLayerNorm() (dropout): Dropout(p=0.1) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() (dropout): Dropout(p=0.1) (classifier): Linear(in_features=768, out_features=28, bias=True)

从打印出的网络结构可以看出,classifier层的out_features已经设置为了上文的提到的28。另外,我们可以关注一下BertPooler层,如果对于前面步骤中在序列头部拼接[CLS]有疑问的话,通过阅读BertPooler的代码可以明晰该字符的用处。

# link : https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.pyclass BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We \"pool\" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output

上面的代码是BertPooler的实现,可以看出在forward方法中hidden_states[:, 0]只取了第一个字符对应的hidden unit,因此凭借双向Encoder的表征能力,\'[CLS] 符号融合了整个序列的表征信息,因此可以用于以一种低维的方式对整个序列进行表征。

# 待优化的参数param_optimizer = list(model.named_parameters())no_decay = [\'bias\', \'LayerNorm.bias\', \'LayerNorm.weight\']optimizer_grouped_parameters = [ \'params\': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], \'weight_decay\': 0.01 \'params\': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], \'weight_decay\':optimizer = BertAdam(optimizer_grouped_parameters, lr=2e-05, warmup= 0.1 , t_total= 2000)device = \'cpu\'

我记得当时在看《动手学深度学习》一书(3.12节)时,李沐提到权重衰减等价于L2正则化。在bert官方的代码中对于bias项、LayerNorm.bias、LayerNorm.weight项免于正则化。

fine-tuning
# 存储每一个batch的lossloss_collect = []for i in trange(10, desc=\'Epoch\'): for step, batch_data in enumerate( tqdm_notebook(train_dataloder, desc=\'Iteration\')): batch_data = tuple(t.to(device) for t in batch_data) batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data # 对标签进行onehot编码 one_hot = torch.zeros(batch_labels.size(0), 28).long() one_hot_batch_labels = one_hot.scatter_( dim=1, index=torch.unsqueeze(batch_labels, dim=1), src=torch.ones(batch_labels.size(0), 28).long()) logits = model( batch_seqs, batch_seq_masks, batch_seq_segments, labels=None) logits = logits.softmax(dim=1) loss_function = CrossEntropyLoss() loss = loss_function(logits, batch_labels) loss.backward() loss_collect.append(loss.item()) print(\"\\r%f\" % loss, end=\'\') optimizer.step() optimizer.zero_grad()

总共进行了10个epoch的训练,将各个batch的loss写入了loss_collect,下面对loss_collect进行可视化。

loss可视化
plt.figure(figsize=(12,8))plt.plot(range(len(loss_collect)), loss_collect,\'g.\')plt.grid(True)plt.show()

\"《基于BERT
\"《基于BERT

从上图可以看出,loss在前200个batch下降速度明显,随后下降速度逐渐变缓,但从整体趋势以及纵轴的loss绝对值可以看出,loss距离收敛还存在一定空间,如果增大训练样本量及迭代次数,loss依然可以继续减小。

模型持久化
torch.save(model,open(\"fine_tuned_chinese_bert.bin\",\"wb\"))
加载测试数据
test_data = pd.read_pickle(\"title_category_valid.pkl\")test_data.columns = [\'text\',\'label\']# 标签ID化test_data[\'label\'] = le.transform(test_data.label.tolist())# 转换为tensortest_seqs, test_seq_masks, test_seq_segments, test_labels = processor.get_input( dataset=test_data, max_seq_len=30)test_seqs = torch.tensor(test_seqs, dtype=torch.long)test_seq_masks = torch.tensor(test_seq_masks, dtype = torch.long)test_seq_segments = torch.tensor(test_seq_segments, dtype = torch.long)test_labels = torch.tensor(test_labels, dtype = torch.long)test_data = TensorDataset(test_seqs, test_seq_masks, test_seq_segments, test_labels)test_dataloder = DataLoader(dataset= train_data, batch_size = 256)# 用于存储预测标签与真实标签true_labels = []pred_labels = []model.eval()with torch.no_grad(): for batch_data in tqdm_notebook(test_dataloder, desc = \'TEST\'): batch_data = tuple(t.to(device) for t in batch_data) batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data  logits = model( batch_seqs, batch_seq_masks, batch_seq_segments, labels=None) logits = logits.softmax(dim=1).argmax(dim = 1) pred_labels.append(logits.detach().numpy()) true_labels.append(batch_labels.detach().numpy())# 查看各个类别的准召print(classification_report(np.concatenate(true_labels), np.concatenate(pred_labels))) 
 precision recall f1-score support 0 0.93 0.95 0.94 1000 1 0.88 0.90 0.89 1000 2 0.91 0.92 0.91 1000 3 0.88 0.95 0.92 1000 4 0.88 0.92 0.90 1000 5 0.91 0.91 0.91 1000 6 0.85 0.84 0.84 1000 7 0.93 0.97 0.95 1000 8 0.88 0.94 0.91 1000 9 0.77 0.86 0.81 1000 10 0.97 0.94 0.96 1000 11 0.85 0.90 0.88 1000 12 0.91 0.97 0.94 1000 13 0.75 0.86 0.80 1000 14 0.84 0.90 0.87 1000 15 0.77 0.87 0.82 1000 16 0.91 0.95 0.93 1000 17 0.96 0.95 0.95 1000 18 0.91 0.93 0.92 1000 19 0.92 0.94 0.93 1000 20 0.94 0.93 0.93 1000 21 0.80 0.80 0.80 1000 22 0.93 0.97 0.95 1000 23 0.82 0.86 0.84 1000 24 0.00 0.00 0.00 1000 25 0.92 0.93 0.93 1000 26 0.89 0.90 0.89 1000 27 0.89 0.89 0.89 1000 micro avg 0.88 0.88 0.88 28000 macro avg 0.85 0.88 0.86 28000weighted avg 0.85 0.88 0.86 28000

可以看出,整体的准召还是比较理想的,不过因为训练和测试都是使用的平衡数据集,因此在真实分布上的准召与该数据集存在一定差异。

本文主要是对run_classifier.py的代码进行了简化,然后在中文数据集上进行了fine-tuning。具体的数据集和代码在文中进行了提供和展示,欢迎交流!

本文链接: http://npfine.immuno-online.com/view-728497.html

发布于 : 2021-03-25 阅读(0)