# PCBERT: Parent and Child BERT for Chinese Few-shot NER
# 前言
这是一篇发表于 COLING 2022 的文章,本文主要分析其 PCBERT 的 CBERT 源码部分。
# C-BERT 源码部分
在 C-BERT 部分的 BERT 仍然保持 LEBERT 部分,主要的改动部分在于 prompt 部分,也就是对应的 P-BERT 部分。
# [batch_size, 4, 512]
prompt_inputs = it['prompt_input_ids']
prompt_inputs = prompt_inputs.reshape(-1, 512) # [batch_size * 7, 512]
prompt_origin_labels = it['prompt_origin_labels']
prompt_origin_labels = prompt_origin_labels.reshape(-1, 512) # [batch_size * 7, 512]
prompt_attention_mask = prompt_origin_labels.gt(0)
prompt_indexed = []
for i in range(it['input_ids'].shape[0]):
it['prompt_indexes'][i] = it['prompt_indexes'][i] + 2048 * i
prompt_indexed += it['prompt_indexes'][i]
prompt_outputs = self.prompt_model(input_ids=prompt_inputs, attention_mask=prompt_attention_mask)
prompt_hidden_states = prompt_outputs.last_hidden_state # [batch_size * 7, 512, 768]
prompt_hidden_states = prompt_hidden_states * prompt_attention_mask.unsqueeze(-1).float() # [batch_size * 7, 512, 768]
prompt_hidden_states = prompt_hidden_states.reshape(-1, 768) # [batch_size * 7 * 512, 768]
prompt_entity_hs = prompt_hidden_states[prompt_indexed] # [batch_size * max_seq_len * entity_pad_len(4), 768]
prompt_entity_hs = prompt_entity_hs.reshape(it['input_ids'].shape[0], -1, 4, 768) # [batch_size, max_seq_len, entity_pad_len(4), 768]
# prompt_entity_hs_fusion = torch.mean(prompt_entity_hs, dim=2) # [batch_size, max_seq_len, 768]
prompt_entity_hs_fusion = prompt_entity_hs.reshape(it['input_ids'].shape[0], it['input_ids'].shape[1], -1) # [batch_size, max_seq_len, 4 * 768]
上述代码第 16 行是把每个 index 对应的 Label [MASK] 部分取出来,由于每个例子都是等长的,且取出的时候是按序的,因此可以直接 reshape 形成 [batch_size, seq_length, 4, 768] 这样的长度,其中 4 是 entity 的固定长度。
第二个改动部分是 CRF 部分,主要的改动是使用自注意力机制融合再进行叠加:
features, masks = self.__build_features(embeds, masks)
#[N,L,D*4] -> [N,L,D]
prompt_feature = self.dense(prompt)
prompt_feature_weights = self.attention(prompt_feature)
prompt_feature_weights = torch.softmax(prompt_feature_weights, dim=1)
prompt_feature = prompt_feature * prompt_feature_weights
prompt_feature = self.act(prompt_feature)
prompt_feature = self.layerNorm(prompt_feature)
fusion = torch.cat([features, prompt_feature[:, :features.shape[1], :]], dim=-1)
fusion = self.fc1(fusion)
经过一层注意力层(fc+relu+fc)获取注意力权重矩阵,然后得到提示特征。最后通过拼接 + 全连接降维得到最后的融合特征,而这个特征就作为 emission_score 传入 CRF。