The embeddings generated by transformer encoders are rich with contextual information. Using them effectively requires a well-chosen pooling technique that distills the token-level embeddings into a single representation, which is then fed into the classification or regression layers.

In this blog post, we’ll dive deep into the world of embedding pooling techniques for transformer encoders. This often underappreciated aspect is crucial for building very strong classifiers and regressors.

Table of Contents

  • Pooling
  • What Different Layers of a Transformer Learn
  • Pooling Techniques
    • CLS Pooling
    • Mean Pooling
    • Max Pooling
    • Mean-Max Pooling
    • Concatenation Pooling
    • LSTM Pooling
    • Weighted Layer Pooling
    • Attention Pooling
  • Build Your Pooling Strategy

Pooling

In this context, pooling refers to the methods used to summarize the embeddings produced by the transformer. Transformer encoders typically output a sequence of vectors, one for each input token, but many classification and regression tasks require a single, fixed-size representation of the entire input. This is where pooling comes in: it's the bridge between the rich, contextual embeddings of the transformer and the final prediction layer.
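
To make the shapes concrete, here is a minimal sketch in which a random tensor stands in for a real encoder output (the batch size, sequence length and hidden size are arbitrary assumptions). Whatever the pooling method, it collapses the sequence dimension so that each input ends up as one vector.

```python
import torch

# Hypothetical encoder output: 2 sequences, 5 tokens each, hidden size 8
batch_size, seq_len, hidden_size = 2, 5, 8
token_embeddings = torch.randn(batch_size, seq_len, hidden_size)

# Pooling collapses the sequence dimension (dim=1); a plain unmasked mean is used
# here purely for illustration (padding handling is covered in later sections)
pooled = token_embeddings.mean(dim=1)

print(token_embeddings.shape)  # torch.Size([2, 5, 8])
print(pooled.shape)            # torch.Size([2, 8])
```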


What Different Layers of a Transformer Learn

Encoder-only transformers are built from a stack of encoder layers, and each layer learns a different level of representation. Together they capture a rich hierarchy of linguistic information: surface features in the lower layers, syntactic features in the middle layers, and semantic features in the higher layers.

Several studies have probed what each layer learns about language, for example the paper "What Does BERT Learn about the Structure of Language?"

The paper revealed several interesting insights, as shown in the screenshot below:

  • Lower layers achieve higher scores on tasks requiring surface features, while failing on tasks that require capturing semantic features.
  • Middle layers achieve higher scores on tasks requiring syntactic features and perform reasonably well on the other tasks, which makes them a good option for pooling in most cases.
  • Higher layers perform best on tasks requiring semantic features and struggle with simple tasks that only need basic surface features.
  • Using only the last layer’s embeddings might therefore restrict the power of the pre-trained representation.

[Screenshot from the paper]

Pooling Techniques

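Let’s load the roberta-base model using the Hugging Face Transformers library and encode a batch of text. The resulting outputs are reused throughout the examples below.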

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# load your text, should be a list of strings
text = ['this is a string'] * 16

# load model and tokenizer
model = AutoModel.from_pretrained('roberta-base')
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
config = model.config  # provides hidden_size, num_hidden_layers, etc. for the heads below

# encode text
features = tokenizer.batch_encode_plus(
    text,
    add_special_tokens=True,
    padding='max_length',
    max_length=256,
    truncation=True,
    return_tensors='pt',
    return_attention_mask=True
)

# getting the outputs of the model
with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])

CLS Pooling

Note: [CLS] is the beginning-of-sequence token in many encoder-only transformers; RoBERTa calls it <s>.

CLS pooling is one of the simplest forms of pooling and is the default approach in many of the Hugging Face Transformers sequence classification models. This method uses the embedding of the first token of the sequence, the [CLS] token, and ignores the embeddings of the remaining tokens. The motivation behind this approach is the belief that the [CLS] token's embedding captures contextual information about the entire sequence.

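# cls pooling: take the first token's embedding from the last layer
# (outputs[1], the pooler output, is that same embedding passed through an extra dense + tanh layer)
cls_embeddings = outputs[0][:, 0]

# regression head
logits = nn.Linear(config.hidden_size, 1)(cls_embeddings)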
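
To see the difference between the raw first-token embedding and a pooler output without downloading a model, here is a minimal sketch on a random tensor. The `pooler` below only mimics the structure of a BERT-style pooler (a dense layer followed by tanh on the first token); its weights are randomly initialised, not pre-trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 8
# Random stand-in for last_hidden_state: (batch, seq_len, hidden)
last_hidden_state = torch.randn(2, 5, hidden_size)

# Raw CLS pooling: the first token's embedding, used as-is
cls_embedding = last_hidden_state[:, 0]

# BERT-style pooler structure: dense + tanh on the first token (random weights here)
pooler = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
pooler_style_output = pooler(cls_embedding)

print(cls_embedding.shape, pooler_style_output.shape)
```

Both are (batch, hidden) vectors, but the second one is squashed into [-1, 1] by tanh and depends on extra weights, so the two are not interchangeable.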

Mean Pooling

Mean Pooling is another common technique for summarizing the embeddings generated by transformer encoders. Unlike CLS pooling, which focuses on a single token, Mean Pooling takes all tokens in the sequence into account. It averages the token embeddings across the sequence-length dimension, resulting in a single, fixed-size vector that represents the entire input sequence.

# getting all of the tokens embeddings of the last layer
last_hidden_state = outputs[0]

def mean_pooling(last_hidden_state, attention_mask):
  # expand the attention mask so that padding tokens contribute nothing to the sum,
  # then divide by the number of real (non-padding) tokens in each sequence
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
  sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)

  sum_mask = input_mask_expanded.sum(1)
  sum_mask = torch.clamp(sum_mask, min=1e-9)

  return sum_embeddings / sum_mask

mean_embeddings = mean_pooling(last_hidden_state, features['attention_mask'])

# regression head
logits = nn.Linear(config.hidden_size, 1)(mean_embeddings)
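
A tiny worked example shows why the attention mask matters. The values below are made up: one sequence with two real tokens and one padding position holding large values. A naive mean would be dominated by the padding, while the masked mean only averages the real tokens.

```python
import torch

# One sequence of 3 positions with hidden size 2; the last position is padding
last_hidden_state = torch.tensor([[[1.0, 2.0],
                                   [3.0, 4.0],
                                   [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

mask = attention_mask.unsqueeze(-1).float()       # (1, 3, 1)
summed = (last_hidden_state * mask).sum(dim=1)    # [[4., 6.]]: padding is zeroed out
counts = mask.sum(dim=1).clamp(min=1e-9)          # [[2.]]: number of real tokens
mean_embeddings = summed / counts

print(mean_embeddings.tolist())  # [[2.0, 3.0]]
```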

Max Pooling

In contrast to Mean Pooling, which averages the embeddings of all tokens, Max Pooling selects the maximum value from each dimension across all token embeddings. This method is particularly effective in highlighting the most salient features of the input sequence.

last_hidden_state = outputs[0]

def max_pooling(last_hidden_state, attention_mask):
  # fill padding positions with the smallest representable value so they can never be the max;
  # masked_fill works out of place, so outputs[0] is left unmodified for later sections
  padding_mask = attention_mask.unsqueeze(-1) == 0
  embeddings = last_hidden_state.masked_fill(padding_mask, torch.finfo(last_hidden_state.dtype).min)
  return torch.max(embeddings, 1)[0]

max_embeddings = max_pooling(last_hidden_state, features['attention_mask'])

# regression head
logits = nn.Linear(config.hidden_size, 1)(max_embeddings) 
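
Here is a toy example of masked max pooling with made-up values. All real values in the first dimension are negative, so without masking, the padding position's 0.0 would wrongly win the max.

```python
import torch

last_hidden_state = torch.tensor([[[-5.0, 1.0],
                                   [-3.0, 0.5],
                                   [ 0.0, 0.0]]])  # last position is padding
attention_mask = torch.tensor([[1, 1, 0]])

# Replace padded positions with the smallest float value so they never win the max
padding = attention_mask.unsqueeze(-1) == 0       # (1, 3, 1), broadcast over hidden dim
masked = last_hidden_state.masked_fill(padding, torch.finfo(last_hidden_state.dtype).min)
max_embeddings = masked.max(dim=1).values

print(max_embeddings.tolist())                    # [[-3.0, 1.0]]
print(last_hidden_state.max(dim=1).values.tolist())  # unmasked: [[0.0, 1.0]]
```

Using `torch.finfo(...).min` instead of a fixed constant such as -1e9 keeps the fill value representable in lower-precision dtypes like float16.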

Mean-Max Pooling

As you might have guessed, it’s a concatenation of both Mean and Max Pooling, so the resulting vector is twice the hidden size.

last_hidden_state = outputs[0]
mean_pooling_embeddings = mean_pooling(last_hidden_state, features['attention_mask'])
max_pooling_embeddings = max_pooling(last_hidden_state, features['attention_mask'])
mean_max_embeddings = torch.cat((mean_pooling_embeddings, max_pooling_embeddings), 1)

# as we have concatenated both embeddings, we have to set twice the hidden size
logits = nn.Linear(config.hidden_size*2, 1)(mean_max_embeddings) 
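
If you prefer to keep the pooling inside a model, the same idea can be packaged as an `nn.Module`. The sketch below is a hypothetical `MeanMaxPooling` class (not part of any library) tested on random tensors, with the last two positions treated as padding.

```python
import torch
import torch.nn as nn

class MeanMaxPooling(nn.Module):
    """Hypothetical module: concatenates masked mean pooling and masked max pooling."""

    def forward(self, last_hidden_state, attention_mask):
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
        mean = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        masked = last_hidden_state.masked_fill(mask == 0, torch.finfo(last_hidden_state.dtype).min)
        max_ = masked.max(dim=1).values
        return torch.cat([mean, max_], dim=-1)  # (batch, 2 * hidden)

hidden_size = 8
pooler = MeanMaxPooling()
head = nn.Linear(2 * hidden_size, 1)  # twice the hidden size because of the concatenation

last_hidden_state = torch.randn(4, 6, hidden_size)
attention_mask = torch.ones(4, 6, dtype=torch.long)
attention_mask[:, 4:] = 0  # pretend the last two positions are padding

features = pooler(last_hidden_state, attention_mask)
logits = head(features)
print(features.shape, logits.shape)  # torch.Size([4, 16]) torch.Size([4, 1])
```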

Concatenation Pooling

Concatenation Pooling is a technique where the outputs from different layers of the transformer are concatenated into a single vector. This approach aims to leverage the rich, hierarchical information captured by multiple layers of the model.

Several papers have experimented with different concatenation strategies; concatenating the last 2 to 4 layers is generally a good starting point.

from transformers import AutoConfig
config = AutoConfig.from_pretrained('roberta-base')
# explicitly update configs to output all of the embeddings produced by all layers
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained('roberta-base', config=config)

# we already loaded tokenizer and encoded text before

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])

# stack the outputs of all layers
all_hidden_states = torch.stack(outputs[2])

# concatenate the embeddings of last 4 layers
concatenate_pooling = torch.cat(
    (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]),-1
)

# we can perform CLS pooling on them
concatenate_pooling = concatenate_pooling[:, 0]

# or we can do mean pooling on each layer then concatenate them
# layer_1 = mean_pooling(all_hidden_states[-1], features['attention_mask'])
# layer_2 = mean_pooling(all_hidden_states[-2], features['attention_mask'])
# layer_3 = mean_pooling(all_hidden_states[-3], features['attention_mask'])
# layer_4 = mean_pooling(all_hidden_states[-4], features['attention_mask'])
# concatenate_pooling = torch.cat(
#    (layer_1, layer_2, layer_3, layer_4), -1
# )

logits = nn.Linear(config.hidden_size*4, 1)(concatenate_pooling) # regression head
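
To make the number of concatenated layers a parameter, the same logic can be wrapped in a small helper. `concat_last_layers_cls` is a hypothetical name, and the random tensor assumes roberta-base-like sizes (12 layers plus the embedding output, hidden size 768).

```python
import torch

def concat_last_layers_cls(all_hidden_states: torch.Tensor, num_layers: int = 4) -> torch.Tensor:
    """Concatenate the first-token embeddings of the last `num_layers` layers.

    Assumes `all_hidden_states` has shape (num_hidden_layers + 1, batch, seq_len, hidden),
    i.e. the stacked hidden states returned with output_hidden_states=True.
    """
    # Last layer first: [-1], [-2], ..., [-num_layers]
    selected = [all_hidden_states[-i][:, 0] for i in range(1, num_layers + 1)]
    return torch.cat(selected, dim=-1)  # (batch, num_layers * hidden)

# Random stand-in with roberta-base-like sizes
all_hidden_states = torch.randn(13, 2, 10, 768)
pooled = concat_last_layers_cls(all_hidden_states, num_layers=4)
print(pooled.shape)  # torch.Size([2, 3072])
```

The regression head's input size then becomes `num_layers * hidden_size`, which is why the example in this section uses `config.hidden_size*4`.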

LSTM Pooling

LSTM Pooling leverages a Long Short-Term Memory (LSTM) network to pool the embeddings of a transformer model. Instead of a simple statistic such as the mean or max, it feeds the [CLS] embedding of every layer, in order from the first layer to the last, into an LSTM and uses the LSTM's final hidden state as the pooled representation. This lets the pooling step learn how the representation evolves across the layers.

# stack all hidden states
all_hidden_states = torch.stack(outputs[2])

class LSTMPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_lstm):
        super(LSTMPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_lstm = hiddendim_lstm
        self.lstm = nn.LSTM(self.hidden_size, self.hiddendim_lstm, batch_first=True)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, all_hidden_states):
        # first-token embedding of every transformer layer (index 0 is the embedding layer),
        # stacked along dim=1 -> (batch, num_layers, hidden): one LSTM time step per layer
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0]
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=1)
        out, _ = self.lstm(hidden_states, None)
        out = self.dropout(out[:, -1, :])
        return out

hiddendim_lstm = 256
pooler = LSTMPooling(config.num_hidden_layers, config.hidden_size, hiddendim_lstm)
lstm_pooling_embeddings = pooler(all_hidden_states)

# regression head
logits = nn.Linear(hiddendim_lstm, 1)(lstm_pooling_embeddings) 
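
The layout of the LSTM input matters: with `batch_first=True`, `nn.LSTM` expects (batch, time steps, features), so the layer axis has to sit at dim=1. Reshaping a (batch, hidden, num_layers) tensor with `.view` would mix values across layers instead of transposing them. The shape check below uses random tensors with arbitrary small sizes (assumptions, not roberta-base values) and an equivalent `permute`-based construction.

```python
import torch
import torch.nn as nn

num_layers, batch_size, seq_len, hidden_size, hiddendim_lstm = 12, 3, 7, 16, 8
# Random stand-in for torch.stack(outputs[2]): (num_layers + 1, batch, seq_len, hidden)
all_hidden_states = torch.randn(num_layers + 1, batch_size, seq_len, hidden_size)

# First-token embedding of each transformer layer (skipping index 0, the embedding layer),
# moved to (batch, num_layers, hidden) so each layer is one LSTM time step
cls_per_layer = all_hidden_states[1:, :, 0].permute(1, 0, 2)

lstm = nn.LSTM(hidden_size, hiddendim_lstm, batch_first=True)
out, _ = lstm(cls_per_layer)   # (batch, num_layers, hiddendim_lstm)
pooled = out[:, -1, :]         # output after the last layer's time step
print(pooled.shape)            # torch.Size([3, 8])
```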

Weighted Layer Pooling

Weighted Layer Pooling combines the representations from several layers of a transformer model into a single, more informative embedding, using a learnable weight per layer. It builds on the observation that the most transferable contextualized representations of input text often occur in the middle layers, while the top layers specialize more for language modeling tasks.

# stack all hidden states
all_hidden_states = torch.stack(outputs[2])

class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights = None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        self.layer_weights = layer_weights if layer_weights is not None \
            else nn.Parameter(
                torch.tensor([1] * (num_hidden_layers+1 - layer_start), dtype=torch.float)
            )

    def forward(self, all_hidden_states):
        all_layer_embedding = all_hidden_states[self.layer_start:, :, :, :]
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())
        weighted_average = (weight_factor*all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
        return weighted_average

# combine layers 9 through the final layer (index 0 of all_hidden_states is the embedding layer)
layer_start = 9
pooler = WeightedLayerPooling(
    config.num_hidden_layers, 
    layer_start=layer_start, layer_weights=None
)
weighted_pooling_embeddings = pooler(all_hidden_states)
weighted_pooling_embeddings = weighted_pooling_embeddings[:, 0]

# regression head
logits = nn.Linear(config.hidden_size, 1)(weighted_pooling_embeddings)
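
The forward pass computes a weighted average over the layer axis, h = (Σ_l w_l · h_l) / (Σ_l w_l). A worked example with two made-up layers and weights [1, 3] shows the arithmetic: (1·[1, 1] + 3·[3, 5]) / 4 = [2.5, 4.0].

```python
import torch

# Two hypothetical layers' embeddings for a single token, shape (num_layers, batch, seq_len, hidden)
layers = torch.tensor([[[[1.0, 1.0]]],
                       [[[3.0, 5.0]]]])
layer_weights = torch.tensor([1.0, 3.0])

# Weighted sum over the layer dimension, normalised by the sum of the weights
weighted = (layer_weights[:, None, None, None] * layers).sum(dim=0) / layer_weights.sum()
print(weighted.tolist())  # [[[2.5, 4.0]]]
```

With all weights initialised to 1, as in the class above, training starts from a plain average of the selected layers and learns to favour the more useful ones.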

Attention Pooling

Attention Pooling is a technique that uses an attention mechanism to dynamically weight the importance of each token in the sequence when generating a single, fixed-size representation. This method allows the model to focus on the most relevant parts of the input, making it particularly effective for tasks where different parts of the input have varying degrees of importance.

This technique has performed well in several Kaggle NLP competitions.

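# token embeddings of the last layer, (batch, seq_len, hidden); outputs[1] is the pooler output
last_hidden_state = outputs[0]

# defining attention pooling: a small MLP that scores each token
attention_pooling = nn.Sequential(
    nn.Linear(config.hidden_size, 512),
    nn.Tanh(),
    nn.Linear(512, 1)
)

scores = attention_pooling(last_hidden_state)  # (batch, seq_len, 1)
# mask padding tokens so they get zero weight, then normalize over the sequence dimension
padding_mask = features['attention_mask'].unsqueeze(-1) == 0
weights = torch.softmax(scores.masked_fill(padding_mask, float('-inf')), dim=1)
context_vector = torch.sum(weights * last_hidden_state, dim=1)

logits = nn.Linear(config.hidden_size, 1)(context_vector)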
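
The masking and normalisation steps are easier to follow with fixed numbers. In this sketch, the scores are made up and stand in for what the small MLP would produce. The padding position receives the highest raw score, yet it ends up with zero weight once it is masked with -inf before the softmax.

```python
import torch

# One sequence, 3 positions, hidden size 2; the last position is padding
hidden = torch.tensor([[[1.0, 0.0],
                        [0.0, 1.0],
                        [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
# Hypothetical unnormalised scores, shape (batch, seq_len, 1)
scores = torch.tensor([[[0.0], [0.0], [5.0]]])

scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, float('-inf'))
weights = torch.softmax(scores, dim=1)          # normalise over the tokens
context_vector = (weights * hidden).sum(dim=1)  # (batch, hidden)

print(weights.squeeze(-1).tolist())  # [[0.5, 0.5, 0.0]]
print(context_vector.tolist())       # [[0.5, 0.5]]
```

One caveat of this sketch: if every position of a sequence were masked, the softmax would produce NaNs, so it assumes each input has at least one real token.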

Build Your Pooling Strategy

Creating an effective pooling strategy often involves combining different techniques or layers to maximize performance for your specific task. Here are some ways you can build and customize your pooling strategy:

  • Combine CLS Pooling and Mean Pooling: Perform CLS Pooling to capture the contextual information from the first token and Mean Pooling to aggregate information from all tokens. Concatenating the outputs of both techniques can provide a richer representation.
  • Concatenate Different Layers: Experiment with concatenating outputs from different layers of the transformer. For instance, concatenate the last four layers, as suggested by some studies, or try other combinations, such as middle layers (syntactic information) together with higher layers (semantic features).
  • Experiment and Iterate: Don’t hesitate to experiment with different pooling strategies and iterate based on performance.
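
As an example of the first strategy above, here is a minimal sketch of a head that concatenates CLS pooling and masked mean pooling. `ClsMeanHead` is a hypothetical name, and the random inputs stand in for `outputs[0]` and the attention mask.

```python
import torch
import torch.nn as nn

class ClsMeanHead(nn.Module):
    """Hypothetical head: first-token embedding concatenated with the masked mean."""

    def __init__(self, hidden_size: int, num_outputs: int = 1):
        super().__init__()
        self.fc = nn.Linear(2 * hidden_size, num_outputs)

    def forward(self, last_hidden_state, attention_mask):
        cls = last_hidden_state[:, 0]                                    # (batch, hidden)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
        mean = (last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return self.fc(torch.cat([cls, mean], dim=-1))                   # (batch, num_outputs)

head = ClsMeanHead(hidden_size=8)
last_hidden_state = torch.randn(2, 5, 8)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
logits = head(last_hidden_state, attention_mask)
print(logits.shape)  # torch.Size([2, 1])
```

The same pattern extends to other combinations: compute each pooled vector, concatenate them, and size the final linear layer to the total width.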
