Edge probing BERT and other language models

Recently, there has been a lot of research in the field of “BERTology”: investigating what aspects of language BERT does and doesn’t understand, using probing techniques. In this post, I will describe the “edge probing” technique developed by Tenney et al. in two papers: “What do you learn from context? Probing for sentence structure in contextualized word representations” (2019a) and “BERT Rediscovers the Classical NLP Pipeline” (2019b). On my first read through these papers, the method didn’t seem very complicated, and the authors spend only a paragraph explaining how it works. On closer inspection, though, the method is nontrivial, and it took me some time to understand. The details are there, but they are tucked away in an appendix of the first paper.

The setup for edge probing is as follows: you have a classification task that takes a sentence and a span of consecutive tokens within that sentence, and produces an N-way classification. This can be generalized to two spans in the sentence, but we’ll focus on the single-span case. The tasks cover a range of syntactic and semantic phenomena, including part-of-speech tagging, dependency parsing, and coreference resolution.
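To make the input format concrete, here is a minimal sketch of a single-span instance. The class name `SpanExample`, its fields, and the half-open span convention are illustrative assumptions of mine, not something taken from the papers.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SpanExample:
    """One single-span classification instance (hypothetical container)."""
    tokens: List[str]      # the tokenized sentence
    span: Tuple[int, int]  # [start, end) token indices, assumed half-open
    label: str             # one of the N task labels

    def span_tokens(self) -> List[str]:
        """Return the consecutive tokens covered by the span."""
        start, end = self.span
        return self.tokens[start:end]


example = SpanExample(
    tokens=["Ada", "Lovelace", "lived", "in", "London"],
    span=(0, 2),
    label="person",
)
print(example.span_tokens())  # ['Ada', 'Lovelace']
```

A two-span task would carry a second `span` field on the same sentence; everything else stays the same.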

[Image: tenney-examples]

Above: Examples of tasks where edge probing may be used. Table taken from Tenney et al. (2019a).

Let’s go through the steps of how edge probing works. Suppose we want to find out which parts of BERT-BASE contain the most information about named entity recognition (NER). In this NER setup, we’re given the named entity spans and only need to classify which type of entity each one is (e.g., person, organization, or location). The first step is to feed the sentence through BERT, which has 12 layers; each layer produces a 768-dimensional vector for every token in the sentence.
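The shapes involved are easier to see in code. The sketch below uses random numbers as a stand-in for the real BERT activations (an assumption made so the example runs without downloading a model), and it ignores wordpiece tokenization by treating each word as one token.

```python
import numpy as np

NUM_LAYERS, HIDDEN = 12, 768  # BERT-base sizes used in this post

rng = np.random.default_rng(0)
tokens = ["Ada", "Lovelace", "lived", "in", "London"]

# Stand-in for frozen BERT activations: one vector per layer per token.
# A real run would take these from the model's per-layer hidden states.
activations = rng.standard_normal((NUM_LAYERS, len(tokens), HIDDEN))

# The entity span "Ada Lovelace", as half-open token indices (assumed convention).
start, end = 0, 2
span_activations = activations[:, start:end, :]

print(activations.shape)       # (12, 5, 768)
print(span_activations.shape)  # (12, 2, 768)
```

Only `span_activations` is passed on to the probe, so the probe sees the rest of the sentence only through whatever context BERT has already encoded into the span tokens.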

[Image: edge_probe]

Above: Edge probing example for NER.

The probing model has several stages:

  1. Mix: learn a task-specific linear combination of the layers, giving a single 768-dimensional vector for each span token. The weights that are learned indicate how much useful information for the task is contained in each layer.
  2. Projection: learn a linear mapping from 768 down to 256 dimensions.
  3. Self-attention pooling: learn a function to generate a scalar weight for each span vector. Then, we normalize the weights to sum up to 1, and take a weighted average. The purpose of this is to collapse the variable-length sequence of span vectors into a single fixed-length 256-dimensional vector.
  4. Feedforward NN: learn a multi-layer perceptron classifier with 2 hidden layers.
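The four stages above can be sketched as a forward pass in NumPy. This is my own minimal reading of the description, not the authors’ implementation: the softmax normalization of the layer weights, the `tanh` activations, the MLP hidden size of 256, the four output classes, and all parameter names are assumptions.

```python
import numpy as np

NUM_LAYERS, HIDDEN, PROJ, MLP_HIDDEN, NUM_CLASSES = 12, 768, 256, 256, 4
rng = np.random.default_rng(0)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def init(*shape):
    return rng.standard_normal(shape) * 0.02


# Probe parameters: these are what training would update.
params = {
    "mix_logits": np.zeros(NUM_LAYERS),                   # one scalar per layer
    "W_proj": init(HIDDEN, PROJ), "b_proj": np.zeros(PROJ),
    "w_attn": init(PROJ),                                 # scores each span token
    "W1": init(PROJ, MLP_HIDDEN), "b1": np.zeros(MLP_HIDDEN),
    "W2": init(MLP_HIDDEN, MLP_HIDDEN), "b2": np.zeros(MLP_HIDDEN),
    "W_out": init(MLP_HIDDEN, NUM_CLASSES), "b_out": np.zeros(NUM_CLASSES),
}


def pool_span(span_acts, p):
    """Stages 1-3: (layers, span_len, 768) -> (256,)."""
    # 1. Mix: weighted combination of layers (softmax-normalized, an assumption).
    layer_weights = softmax(p["mix_logits"])
    mixed = np.tensordot(layer_weights, span_acts, axes=1)  # (span_len, 768)
    # 2. Projection: 768 -> 256 for every span token.
    projected = mixed @ p["W_proj"] + p["b_proj"]           # (span_len, 256)
    # 3. Self-attention pooling: one scalar per token, normalized to sum to 1.
    attn = softmax(projected @ p["w_attn"])                 # (span_len,)
    return attn @ projected                                 # (256,)


def classify(span_vec, p):
    """Stage 4: MLP with two hidden layers -> class probabilities."""
    h1 = np.tanh(span_vec @ p["W1"] + p["b1"])
    h2 = np.tanh(h1 @ p["W2"] + p["b2"])
    return softmax(h2 @ p["W_out"] + p["b_out"])


span_acts = rng.standard_normal((NUM_LAYERS, 3, HIDDEN))  # stand-in for BERT output
span_vec = pool_span(span_acts, params)
probs = classify(span_vec, params)
print(span_vec.shape, probs.shape)  # (256,) (4,)
```

Note that `pool_span` returns a 256-dimensional vector whatever the span length is, which is exactly what lets a fixed-size classifier sit on top of variable-length spans.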

For the two-span case, they use two separate sets of weights for the mix, projection, and self-attention steps. Then, the feedforward neural network takes a concatenated 512-dimensional vector instead of a 256-dimensional vector.
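As a rough sketch of the two-span variant, each span gets its own independently initialized parameters, and the two pooled vectors are concatenated before classification. Biases are dropped for brevity, and the helper names, the softmax-normalized mix, and the random stand-in activations are assumptions rather than details from the papers.

```python
import numpy as np

NUM_LAYERS, HIDDEN, PROJ = 12, 768, 256
rng = np.random.default_rng(0)


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def new_span_params():
    """Independent mix / projection / attention weights for one span."""
    return {
        "mix_logits": np.zeros(NUM_LAYERS),
        "W_proj": rng.standard_normal((HIDDEN, PROJ)) * 0.02,
        "w_attn": rng.standard_normal(PROJ) * 0.02,
    }


def pool_span(span_acts, p):
    """(layers, span_len, 768) -> (256,) via mix, projection, attention pooling."""
    mixed = np.tensordot(softmax(p["mix_logits"]), span_acts, axes=1)
    projected = mixed @ p["W_proj"]
    return softmax(projected @ p["w_attn"]) @ projected


# Two separate parameter sets, one per span.
params_1, params_2 = new_span_params(), new_span_params()

activations = rng.standard_normal((NUM_LAYERS, 7, HIDDEN))  # stand-in for BERT
span_1 = activations[:, 0:2, :]  # e.g. a two-token span
span_2 = activations[:, 4:5, :]  # e.g. a one-token span

# The classifier input is the concatenation of the two pooled span vectors.
pair_vec = np.concatenate([pool_span(span_1, params_1), pool_span(span_2, params_2)])
print(pair_vec.shape)  # (512,)
```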

The edge probing model needs to be trained on a dataset of classification instances. The probe has weights that are initialized randomly and trained using gradient descent, but the BERT weights are kept constant while the probe is being trained.
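Here is a toy illustration of that training regime. To keep the gradients simple enough to write by hand, the probe is reduced to a single softmax layer and the frozen encoder outputs are random vectors with synthetic labels; both are simplifications of mine, and the sizes and learning rate are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_EXAMPLES, DIM, NUM_CLASSES = 200, 32, 3
LEARNING_RATE = 0.5

# Frozen "encoder" outputs: computed once and never updated during training.
features = rng.standard_normal((NUM_EXAMPLES, DIM))
true_W = rng.standard_normal((DIM, NUM_CLASSES))
labels = (features @ true_W).argmax(axis=1)  # synthetic labels
frozen_copy = features.copy()

# Probe parameters: randomly initialized, the only values gradient descent touches.
W = rng.standard_normal((DIM, NUM_CLASSES)) * 0.01
b = np.zeros(NUM_CLASSES)


def forward(x):
    logits = x @ W + b
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


def cross_entropy(probs, y):
    return -np.log(probs[np.arange(len(y)), y]).mean()


one_hot = np.eye(NUM_CLASSES)[labels]
initial_loss = cross_entropy(forward(features), labels)

for step in range(200):
    probs = forward(features)
    grad_logits = (probs - one_hot) / NUM_EXAMPLES  # d(loss)/d(logits)
    W -= LEARNING_RATE * (features.T @ grad_logits)
    b -= LEARNING_RATE * grad_logits.sum(axis=0)

final_loss = cross_entropy(forward(features), labels)
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

The point to notice is that no gradient is ever applied to `features`: whatever the probe learns, it has to learn from the representations as they are.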

This setup is more sophisticated than it looks at first glance, and it turns out that the probe itself is quite powerful. In one of the experiments, the authors found that even with randomized input, the probe was able to reach over 70% of the full model’s performance!

Edge probing is not the only way of probing deep language models. Other probing experiments (Liu et al., 2019) used a simple linear classifier, which is better suited to measuring what information can be easily recovered from the representations, since a less expressive probe has less capacity to learn the task on its own.

References

  1. Tenney, Ian, et al. “What do you learn from context? Probing for sentence structure in contextualized word representations.” International Conference on Learning Representations. 2019a.
  2. Tenney, Ian, Dipanjan Das, and Ellie Pavlick. “BERT Rediscovers the Classical NLP Pipeline.” Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics. 2019b.
  3. Liu, Nelson F., et al. “Linguistic Knowledge and Transferability of Contextual Representations.” Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). 2019.