# PCBERT: Parent and Child BERT for Chinese Few-shot NER

# 前言

这是一篇发表于 COLING 2022 的文章,本文主要分析其 PCBERT 的 CBERT 源码部分。

# C-BERT 源码部分

  1. 在 C-BERT 部分的 BERT 仍然保持 LEBERT 部分,主要的改动部分在于 prompt 部分,也就是对应的 P-BERT 部分。

    # [batch_size, 4, 512]
    prompt_inputs = it['prompt_input_ids']
    prompt_inputs = prompt_inputs.reshape(-1, 512) # [batch_size * 7, 512]
    prompt_origin_labels = it['prompt_origin_labels']
    prompt_origin_labels = prompt_origin_labels.reshape(-1, 512) # [batch_size * 7, 512]
    prompt_attention_mask = prompt_origin_labels.gt(0)
    prompt_indexed = []
    for i in range(it['input_ids'].shape[0]):
        it['prompt_indexes'][i] = it['prompt_indexes'][i] + 2048 * i
        prompt_indexed += it['prompt_indexes'][i]
    prompt_outputs = self.prompt_model(input_ids=prompt_inputs, attention_mask=prompt_attention_mask)
    prompt_hidden_states = prompt_outputs.last_hidden_state # [batch_size * 7, 512, 768]
    prompt_hidden_states = prompt_hidden_states * prompt_attention_mask.unsqueeze(-1).float() # [batch_size * 7, 512, 768]
    prompt_hidden_states = prompt_hidden_states.reshape(-1, 768) # [batch_size * 7 * 512, 768]
    prompt_entity_hs = prompt_hidden_states[prompt_indexed] # [batch_size * max_seq_len * entity_pad_len(4), 768]
    prompt_entity_hs = prompt_entity_hs.reshape(it['input_ids'].shape[0], -1, 4, 768) # [batch_size, max_seq_len, entity_pad_len(4), 768]
    # prompt_entity_hs_fusion = torch.mean(prompt_entity_hs, dim=2) # [batch_size, max_seq_len, 768]
    prompt_entity_hs_fusion = prompt_entity_hs.reshape(it['input_ids'].shape[0], it['input_ids'].shape[1], -1) # [batch_size, max_seq_len, 4 * 768]

    上述代码第 16 行是把每个 index 对应的 Label [MASK] 部分取出来,由于每个例子都是等长的,且取出的时候是按序的,因此可以直接 reshape 形成 [batch_size, seq_length, 4, 768] 这样的长度,其中 4 是 entity 的固定长度。

  2. 第二个改动部分是 CRF 部分,主要的改动是使用自注意力机制融合再进行叠加:

    features, masks = self.__build_features(embeds, masks)
    #[N,L,D*4] -> [N,L,D] 
    prompt_feature = self.dense(prompt)
    prompt_feature_weights = self.attention(prompt_feature)
    prompt_feature_weights = torch.softmax(prompt_feature_weights, dim=1)
    prompt_feature = prompt_feature * prompt_feature_weights
    prompt_feature = self.act(prompt_feature)
    prompt_feature = self.layerNorm(prompt_feature)
    fusion = torch.cat([features, prompt_feature[:, :features.shape[1], :]], dim=-1)
    fusion = self.fc1(fusion)

    经过一层注意力层(fc+relu+fc)获取注意力权重矩阵,然后得到提示特征。最后通过拼接 + 全连接降维得到最后的融合特征,而这个特征就作为 emission_score 传入 CRF。

更新于 阅读次数