本篇post主要记录本人在阅读bert官方源码的run_squad.py文件时的一些理解与思考。
概述
run_squad.py文件是google官方给出的基于bert在公开数据集SQuAD2.0上做阅读理解的demo。国内有好多MRC方向的竞赛,如duReader、讯飞杯(CMRC)、法研杯等都是基于SQuAD演变而来,因此run_squad.py可作为MRC的一个切入点。
数据探索
斯坦福大学在2016年推出了SQuAD1.1,并于2018年推出了升级版SQuAD2.0。SQuAD1.1数据集通过众包的方式产生了10W+的问题,且答案为相应文章中的一个片段(抽取式阅读理解)。SQuAD2.0在1.1版本的基础上引入了5W+无法回答的问题,进一步提升了MRC的难度。下面重点分析下SQuAD2.0数据集中的格式与内容。
训练集与验证集都是以json文件的形式给出,简化版的格式与内容如下:
{
	"version": "v2.0",
	"data":[
		{
	    "title": "The_Legend_of_Zelda:_Twilight_Princess",
	    "paragraphs": [
	    	{
	    		"qas":[
	    				{
				        "question": "What year was the Wii version of Legend of Zelda: Twilight Princess released?",
				        "id": "56d1159e17492d1400aab8cf",
				        "answers": [
				        {
			            "text": "2006",
			            "answer_start": 578
				        }],
				        "is_impossible": false
					    },
					    {
				        "plausible_answers": [
				        {
			            "text": "action-adventure",
			            "answer_start": 128
				        }],
				        "question": "What category of game is Legend of Zelda: Australia Twilight?",
				        "id": "5a8d7bf7df8bba001a0f9ab1",
				        "answers": [],
				        "is_impossible": true
					    }
	    		],
	    		"context": "The Legend of Zelda: Twilight Princess (Japanese: \u30bc\u30eb\u30c0\u306e\u4f1d\u8aac \u30c8\u30ef\u30a4\u30e9\u30a4\u30c8\u30d7\u30ea\u30f3\u30bb\u30b9, Hepburn: Zeruda no Densetsu: Towairaito Purinsesu?) is an action-adventure game developed and published by Nintendo for the GameCube and Wii home video game consoles. It is the thirteenth installment in the The Legend of Zelda series. Originally planned for release on the GameCube in November 2005, Twilight Princess was delayed by Nintendo to allow its developers to refine the game, add more content, and port it to the Wii. The Wii version was released alongside the console in North America in November 2006, and in Japan, Europe, and Australia the following month. The GameCube version was released worldwide in December 2006.[b]"
	    	},
	    	{
	    		"qas":[
	    			{
			        "question": "This storyline takes place alternate from what storyline?",
			        "id": "56d1162317492d1400aab8ef",
			        "answers": [
			        {
		            "text": "The Wind Waker",
		            "answer_start": 378
			        }],
			        "is_impossible": false
				    },
				    {
			        "plausible_answers": [
			        {
		            "text": "Hyrule",
		            "answer_start": 67
			        }],
			        "question": "What land does Ocarina serve to protect?",
			        "id": "5a8d800edf8bba001a0f9abb",
			        "answers": [],
			        "is_impossible": true
				    }	
	    		],
	    		"context": "The story focuses on series protagonist Link, who tries to prevent Hyrule from being engulfed by a corrupted parallel dimension known as the Twilight Realm. To do so, he takes the form of both a Hylian and a wolf, and is assisted by a mysterious creature named Midna. The game takes place hundreds of years after Ocarina of Time and Majora's Mask, in an alternate timeline from The Wind Waker."
	    	}
	    ]
		}
	]
}
说明:从外往里看,version属性表示数据集所属版本(v1.1/v2.0),data属性为一个数组,里面存储着每一篇文章(简化版只给出了一篇文章以作说明);而一篇文章有两个属性title(文章标题)和paragraphs(文章的不同段落,简化版给出了一篇文章的两个段落);而每个段落有两个属性context(文章某个段落的内容)和qas(针对段落内容提的多个问题及对应答案的集合);而qas中分为两种问题,一种是可回答的问题,另一种是无法回答的问题。
当问题是可回答的问题时,qas中的对象有如下几个属性:is_impossible(false,表示可回答的问题)、id(问题编号)、question(具体问题)、answers(对应问题的答案,训练集中有且只有一个正确答案,验证集中会有多个相同/不同的答案,评估时按性能值最大的取值,具体见评估脚本);当问题是无法回答的问题时,qas中的对象有如下几个属性:is_impossible(true,表示无法回答的问题)、id(问题编号)、question(具体问题)、answers(空数组)、plausible_answers(似是而非(不正确)的答案)。在可回答问题中的answers与无法回答问题中的plausible_answers内又有两种属性text(答案的文本)与answer_start(答案在对应段落中的起始位置,按字符进行定位的)
至此,SQuAD2.0数据集的格式与内容算是说明完毕了。详细的数据内容分布可参考斯坦福的官方paper。
run_squad.py on colab
经过上述数据集的分析,可先感性的在colab中跑一波代码。
为了方便自己加日志分析,直接fork到自己的仓库先污染一波官方源码。。。开搞
!git clone https://github.com/carlos9310/bert.git
下载预训练的模型并解压
!wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip && unzip uncased_L-12_H-768_A-12.zip
下载SQuAD2.0的训练语料与验证语料
!mkdir SQUAD_DIR && cd SQUAD_DIR && wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
训练(说明: 训练过程中若发现本地内存(我本地16G)不够时,及时清理输出日志,即可释放浏览器占用的内存。本人训练时长: Wall time: 6h 49min 35s)
%%time
!python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_train=True \
  --train_file=SQUAD_DIR/train-v2.0.json \
  --train_batch_size=8 \
  --learning_rate=3e-5 \
  --num_train_epochs=1.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True
预测(预测的时长: Wall time: 13min 58s)
%%time
!python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_predict=True \
  --predict_file=SQUAD_DIR/dev-v2.0.json \
  --learning_rate=3e-5 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True
基于上述生成的预测结果,生成是否可回答的问题的阀值null_score_diff_threshold,初始值为0.0(这部分内容官方仓库的readme有做说明)
说明: evaluate-v2.0.py脚本文件需先手动上传到SQUAD_DIR下,wget下载不下来
!cd SQUAD_DIR && python evaluate-v2.0.py  dev-v2.0.json /tmp/squad2.0_base/predictions.json --na-prob-file /tmp/squad2.0_base/null_odds.json
输出为
{
  "exact": 71.89421376231786,
  "f1": 75.11333314674677,
  "total": 11873,
  "HasAns_exact": 71.94669365721997,
  "HasAns_f1": 78.39416404374565,
  "HasAns_total": 5928,
  "NoAns_exact": 71.84188393608073,
  "NoAns_f1": 71.84188393608073,
  "NoAns_total": 5945,
  "best_exact": 73.42710351217048,
  "best_exact_thresh": -2.1611831188201904,
  "best_f1": 76.15613215155969,
  "best_f1_thresh": -1.867471694946289
}
基于上述生成的best_f1_thresh,增加此参数,重新预测
%%time
!python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_predict=True \
  --predict_file=SQUAD_DIR/dev-v2.0.json \
  --learning_rate=3e-5 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True \
  --null_score_diff_threshold=-1.867471694946289
对重新预测后的结果再进行评估
!cd SQUAD_DIR && python evaluate-v2.0.py  dev-v2.0.json /tmp/squad2.0_base/predictions.json
输出结果
{
  "exact": 73.39341362755833,
  "f1": 76.15613215155983,
  "total": 11873,
  "HasAns_exact": 65.97503373819163,
  "HasAns_f1": 71.50839356198873,
  "HasAns_total": 5928,
  "NoAns_exact": 80.7905803195963,
  "NoAns_f1": 80.7905803195963,
  "NoAns_total": 5945
}
调整阀值后,预测结果有所提升。
至此,SQuAD2.0从训练到验证的流程算是跑完了。至于测试,需到官网上进行提交代码、模型,由官方亲自测试,最终会将结果公布到排行榜(Leaderboard)上。
思路梳理
下面大概梳理下run_squad.py中的处理思路与方法。
由于采用了预训练(pretrained)+微调(fine-tuning)的范式做MRC任务,因此在模型结构上显得简单明了,直接在bert的输出后加一层全连接层即可完成MRC模型的搭建(本质上是一种分类模型)。而run_squad.py工作的重心主要在数据预处理上。
首先,读取json格式的数据集文件,将文件中的内容转成SquadExample对象组成的list,其中每个SquadExample对象有如下几个特征:问题id、问题内容question、段落内容context、答案文本text(无法回答时为空字符串)、答案的起始位置start_position、答案的终止位置end_position(词级别的,已将给定的字符级别的位置转到词级别,这里的词是以空格进行split的)和问题是否可回答的标志is_impossible。
此时,只是将原有的层级结构(非线性)的json格式的数据集转化成水平结构(线性)的list形式,其中每个元素对应一个样本,每个样本有如下特征(qas_id,question,doc_tokens,answer,start_position,end_position,is_impossible)。
具体代码如下:
def read_squad_examples(input_file, is_training):
  """Read a SQuAD json file into a list of SquadExample."""
  with tf.gfile.Open(input_file, "r") as reader:
    input_data = json.load(reader)["data"]
  def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
      return True
    return False
  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:
      paragraph_text = paragraph["context"]
      doc_tokens = []
      char_to_word_offset = []
      prev_is_whitespace = True
      for c in paragraph_text:
        if is_whitespace(c):
          prev_is_whitespace = True
        else:
          if prev_is_whitespace:
            doc_tokens.append(c)
          else:
            doc_tokens[-1] += c
          prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)
      for qa in paragraph["qas"]:
        qas_id = qa["id"]
        question_text = qa["question"]
        start_position = None
        end_position = None
        orig_answer_text = None
        is_impossible = False
        if is_training:
          if FLAGS.version_2_with_negative:
            is_impossible = qa["is_impossible"]
          if (len(qa["answers"]) != 1) and (not is_impossible):
            raise ValueError(
                "For training, each question should have exactly 1 answer.")
          if not is_impossible:
            answer = qa["answers"][0]
            orig_answer_text = answer["text"]
            answer_offset = answer["answer_start"]
            answer_length = len(orig_answer_text)
            start_position = char_to_word_offset[answer_offset]
            end_position = char_to_word_offset[answer_offset + answer_length -
                                               1]
            # Only add answers where the text can be exactly recovered from the
            # document. If this CAN'T happen it's likely due to weird Unicode
            # stuff so we will just skip the example.
            #
            # Note that this means for training mode, every example is NOT
            # guaranteed to be preserved.
            actual_text = " ".join(
                doc_tokens[start_position:(end_position + 1)])
            cleaned_answer_text = " ".join(
                tokenization.whitespace_tokenize(orig_answer_text))
            if actual_text.find(cleaned_answer_text) == -1:
              tf.logging.warning("Could not find answer: '%s' vs. '%s'",
                                 actual_text, cleaned_answer_text)
              continue
          else:
            start_position = -1
            end_position = -1
            orig_answer_text = ""
        example = SquadExample(
            qas_id=qas_id,
            question_text=question_text,
            doc_tokens=doc_tokens,
            orig_answer_text=orig_answer_text,
            start_position=start_position,
            end_position=end_position,
            is_impossible=is_impossible)
        examples.append(example)
  return examples
接着将上述生成的由SquadExample对象组成的list转换成符合bert输入格式的特征,并将转换后的每个特征的部分属性依此写入到.tf_record格式的文件中。
具体地,先将上述生成的SquadExample对象中词级别的question、doc_tokens内容进一步转化为WordPiece级别的token,同时将词级别的start_position、end_position也转换成WordPiece级别对应的位置tok_start_position、tok_end_position。
然后将WordPiece后的token拼接成[CLS] question_tokens [SEP] context_tokens [SEP]的形式。但是考虑到bert对输入长度有限制,即拼接后的总tokens数目不能超过max_seq_length。
为了能建模更长的段落内容context_tokens,采取了一种滑动窗口的策略。先确定当前段落支持的最大tokens数,不超过的段落正常处理,超过的段落以滑动窗口的形式分割成多个子段落。确定好分割后的段落doc_spans,再依次进行拼接成如下形式:[CLS] question_tokens [SEP] doc_span_tokens [SEP]。基于拼接后的WordPiece级别的tokens生成input_ids、input_mask、segment_ids、start_position、end_position等features(训练集中对于无法回答的问题,将start_position与end_position设为0,即对应[CLS]的位置),然后将上述特征转换成InputFeatures对象的属性,最后写入到.tf_record格式的文件中。
具体代码如下:
def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 output_fn):
  """Loads a data file into a list of `InputBatch`s."""
  unique_id = 1000000000
  for (example_index, example) in enumerate(examples):
    query_tokens = tokenizer.tokenize(example.question_text)
    if len(query_tokens) > max_query_length:
      query_tokens = query_tokens[0:max_query_length]
    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for (i, token) in enumerate(example.doc_tokens):
      orig_to_tok_index.append(len(all_doc_tokens))
      sub_tokens = tokenizer.tokenize(token)
      for sub_token in sub_tokens:
        tok_to_orig_index.append(i)
        all_doc_tokens.append(sub_token)
    tok_start_position = None
    tok_end_position = None
    if is_training and example.is_impossible:
      tok_start_position = -1
      tok_end_position = -1
    if is_training and not example.is_impossible:
      tok_start_position = orig_to_tok_index[example.start_position]
      if example.end_position < len(example.doc_tokens) - 1:
        tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
      else:
        tok_end_position = len(all_doc_tokens) - 1
      (tok_start_position, tok_end_position) = _improve_answer_span(
          all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
          example.orig_answer_text)
    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
    # We can have documents that are longer than the maximum sequence length.
    # To deal with this we do a sliding window approach, where we take chunks
    # of the up to our max length with a stride of `doc_stride`.
    _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
        "DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
      length = len(all_doc_tokens) - start_offset
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(_DocSpan(start=start_offset, length=length))
      if start_offset + length == len(all_doc_tokens):
        break
      start_offset += min(length, doc_stride)
    for (doc_span_index, doc_span) in enumerate(doc_spans):
      tokens = []
      token_to_orig_map = {}
      token_is_max_context = {}
      segment_ids = []
      tokens.append("[CLS]")
      segment_ids.append(0)
      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
      tokens.append("[SEP]")
      segment_ids.append(0)
      for i in range(doc_span.length):
        split_token_index = doc_span.start + i
        token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                               split_token_index)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[split_token_index])
        segment_ids.append(1)
      tokens.append("[SEP]")
      segment_ids.append(1)
      input_ids = tokenizer.convert_tokens_to_ids(tokens)
      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)
      # Zero-pad up to the sequence length.
      while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length
      start_position = None
      end_position = None
      if is_training and not example.is_impossible:
        # For training, if our document chunk does not contain an annotation
        # we throw it out, since there is nothing to predict.
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if not (tok_start_position >= doc_start and
                tok_end_position <= doc_end):
          out_of_span = True
        if out_of_span:
          start_position = 0
          end_position = 0
        else:
          doc_offset = len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset
      if is_training and example.is_impossible:
        start_position = 0
        end_position = 0
      if example_index < 20:
        tf.logging.info("*** Example ***")
        tf.logging.info("unique_id: %s" % (unique_id))
        tf.logging.info("example_index: %s" % (example_index))
        tf.logging.info("doc_span_index: %s" % (doc_span_index))
        tf.logging.info("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        tf.logging.info("token_to_orig_map: %s" % " ".join(
            ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
        tf.logging.info("token_is_max_context: %s" % " ".join([
            "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
        ]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info(
            "input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info(
            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        if is_training and example.is_impossible:
          tf.logging.info("impossible example")
        if is_training and not example.is_impossible:
          answer_text = " ".join(tokens[start_position:(end_position + 1)])
          tf.logging.info("start_position: %d" % (start_position))
          tf.logging.info("end_position: %d" % (end_position))
          tf.logging.info(
              "answer: %s" % (tokenization.printable_text(answer_text)))
      feature = InputFeatures(
          unique_id=unique_id,
          example_index=example_index,
          doc_span_index=doc_span_index,
          tokens=tokens,
          token_to_orig_map=token_to_orig_map,
          token_is_max_context=token_is_max_context,
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          start_position=start_position,
          end_position=end_position,
          is_impossible=example.is_impossible)
      # Run callback
      output_fn(feature)
      unique_id += 1
上述代码中的_check_is_max_context函数记录了长段落被划分为多个子段落后,每个子段落中的每个token是否处在最大的上下文中,这一信息在预测的时候可用来辅助判读正确答案。
接着从上述生成的.tf_record格式的文件中加载固定批次大小的dataset输入到bert模型中,具体加载数据的代码如下:
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""
  name_to_features = {
      "unique_ids": tf.FixedLenFeature([], tf.int64),
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
  }
  if is_training:
    name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t
    return example
  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]
    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)
    d = d.apply(
        tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d
  return input_fn
然后将上述加载的数据(targets:unique_ids、input_ids、input_mask、segment_ids;labels:start_positions、end_positions)作为输入喂给扩展后的bert模型(在预训练bert模型的输出后,增加一层全连接层),输出的是每个批次中,每一个token后的文本序列中的每一个token作为答案的开始或结束的几率。 具体代码如下:
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
  """Creates a classification model."""
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)
  final_hidden = model.get_sequence_output()
  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]
  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)
  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  logits = tf.transpose(logits, [2, 0, 1])
  unstacked_logits = tf.unstack(logits, axis=0)
  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
  return (start_logits, end_logits)
其中output_weights、output_bias为新增的全连接层的参数,训练时随机初始化一组值,预测时直接通过张量名加载训练后的值。unstacked_logits的形状为[2,batch_size,seq_length],start_logits的形状为[batch_size,seq_length],表示每个批次中每个序列中的token作为答案起始的几率,end_logits的形状为[batch_size,seq_length],表示每个批次中每个序列中的token作为答案结束的几率。
基于上述模型的预测输出start_logits与end_logits,和真实的start_positions与end_positions,分别计算交叉熵的损失(分类模型典型的损失函数),然后利用已知优化器将损失不断减小,最终经过一定训练步数后确定最终的全连接层的参数。预测时,直接加载确定后的参数值,并由预测样本输出对应的预测结果。具体代码如下:
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)
    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:
        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()
        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]
      def compute_loss(logits, positions):
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss
      start_positions = features["start_positions"]
      end_positions = features["end_positions"]
      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)
      total_loss = (start_loss + end_loss) / 2.0
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))
    return output_spec
  return model_fn
值得说明的是,start_logits与end_logits都是Wordpiece级别的token对应的位置的几率,真实的位置是基于字符级别的且答案文本是词级别的并区分大小写。因此还需进一步将预测后的值还原到词级别中去。 还原时主要利用之前的字符级别到词级别再到Wordpiece级别的转换过程中记录的相关映射关系进行的。相应代码为:
补充: 对于无法回答的问题,也是基于上述预测生成的start_logits与end_logits。在处理时,记录第一个位置的start_logit与end_logit的几率,即[CLS]对应的位置(在数据预处理convert_examples_to_features中,已将无法回答的问题对应的start_position与end_position都设为0)。当第一个位置的两个几率和与span中最有可能的两个几率和的差值大于null_score_diff_threshold(可调节,默认为0),该问题无法回答。否则,该问题可回答,答案为best_non_null_entry.text。
def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file):
  """Write final predictions to the json file and log-odds of null if needed."""
  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)
  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result
  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()
  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]
    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min mull score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
      # if we could have irrelevant answers, get the min score of irrelevant
      if FLAGS.version_2_with_negative:
        feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))
    if FLAGS.version_2_with_negative:
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
              start_index=0,
              end_index=0,
              start_logit=null_start_logit,
              end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"])
    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
        tok_text = " ".join(tok_tokens)
        # De-tokenize WordPieces that have been split off.
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")
        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)
        final_text = get_final_text(tok_text, orig_text, do_lower_case)
        if final_text in seen_predictions:
          continue
        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True
      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))
    # if we didn't inlude the empty option in the n-best, inlcude it
    if FLAGS.version_2_with_negative:
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="", start_logit=null_start_logit,
                end_logit=null_end_logit))
    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
    assert len(nbest) >= 1
    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_logit + entry.end_logit)
      if not best_non_null_entry:
        if entry.text:
          best_non_null_entry = entry
    probs = _compute_softmax(total_scores)
    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)
    assert len(nbest_json) >= 1
    if not FLAGS.version_2_with_negative:
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      # predict "" iff the null score - the score of best non-null > threshold
      score_diff = score_null - best_non_null_entry.start_logit - (
          best_non_null_entry.end_logit)
      scores_diff_json[example.qas_id] = score_diff
      if score_diff > FLAGS.null_score_diff_threshold:
        all_predictions[example.qas_id] = ""
      else:
        all_predictions[example.qas_id] = best_non_null_entry.text
    all_nbest_json[example.qas_id] = nbest_json
  with tf.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")
  with tf.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  if FLAGS.version_2_with_negative:
    with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
预测完成后会生成predictions.json、nbest_predictions.json与null_odds.json共3个json文件。其中predictions.json记录的是每个问题Id对应的最佳答案,无法回答的问题其答案为空字符串,形如如下格式:
{
        "56ddde6b9a695914005b9628": "France",
        "56ddde6b9a695914005b9629": "10th and 11th centuries",
        "56ddde6b9a695914005b962a": "",
        "56ddde6b9a695914005b962b": "Rollo",
        "56ddde6b9a695914005b962c": "10th century",
        "5ad39d53604f3c001a3fe8d1": "",
        "5ad39d53604f3c001a3fe8d2": "",
        "5ad39d53604f3c001a3fe8d3": "",
        "5ad39d53604f3c001a3fe8d4": "",
        "56dddf4066d3e219004dad5f": "William the Conqueror"
}
nbest_predictions.json记录的是每个问题Id对应的n个最佳答案,且n个可能的答案的概率和为1,形如如下形式
{
    "56ddde6b9a695914005b9628": [
        {
            "text": "France",
            "probability": 0.9929679941964172,
            "start_logit": 7.072847366333008,
            "end_logit": 7.759739398956299
        },
        {
            "text": "a region in France",
            "probability": 0.0025574380090684365,
            "start_logit": 1.1111549139022827,
            "end_logit": 7.759739398956299
        },
        {
            "text": "in France",
            "probability": 0.0017225365088163532,
            "start_logit": 0.7159468531608582,
            "end_logit": 7.759739398956299
        }
    ]
}
null_odds.json记录的是每个问题Id无法回到的可能性,如果Id对应的值大于某个阀值(可调整),则该问题无法回答。形如如下形式:
{
        "56ddde6b9a695914005b9628": -10.035285472869873,
        "56ddde6b9a695914005b9629": -9.032696723937988,
        "56ddde6b9a695914005b962a": 0.1835465431213379,
        "56ddde6b9a695914005b962b": -5.336197853088379,
        "56ddde6b9a695914005b962c": -1.4510955810546875,
        "5ad39d53604f3c001a3fe8d1": 6.505279779434204,
        "5ad39d53604f3c001a3fe8d2": 4.115971803665161,
        "5ad39d53604f3c001a3fe8d3": 2.6883482933044434,
        "5ad39d53604f3c001a3fe8d4": 7.815060138702393,
        "56dddf4066d3e219004dad5f": -6.004664421081543
}
至此,run_squad.py文件在SQuAD 2.0数据集上的训练、预测的思路流程算是大概梳理完了。
扩展
上述的SQuAD 2.0属于抽取式的MRC。在SQuAD 2.0之后,斯坦福又进一步提出基于多轮对话的MRC数据集CoQA。CoQA包含8K组对话,其中有12W+带有答案的问题。每组对话的平均长度为15轮,每轮对话由一个问题Q和一个回答A组成,同时回答A对应的片段R也被提取出来。该数据集有如下几个特点
- 
    体现人类对话中问题的本质。表现在问题间存在关联性,当前问题依赖之前问题的答案 
- 
    答案的自然性。表现在有的问题在段落中无法抽取,需进行推理生成 
- 
    问答系统的鲁棒性。表现在问题的多样性上 
参考
- https://github.com/google-research/bert
- SQuAD2.0-The Stanford Question Answering Dataset
- SQuAD: 100,000+ Questions for Machine Comprehension of Text
- Know What You Don’t Know: Unanswerable Questions for SQuAD
- Understand Cross Entropy Loss in Minutes
- CoQA-A Conversational Question Answering Challenge
- 斯坦福发布全新对话问答数据集,可评估机器参与问答式对话的能力