from __future__ import annotations
import json
import os
from typing import Any
import torch
from torch import Tensor, nn
# [docs]
class Pooling(nn.Module):
    """
    Performs pooling (max or mean) on the token embeddings.

    Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding. This layer also
    allows to use the CLS token if it is returned by the underlying word embedding model. You can concatenate
    multiple poolings together: every enabled mode contributes one ``word_embedding_dimension``-sized slice to the
    output, in the order cls, max, mean, mean_sqrt_len_tokens, weightedmean, lasttoken.

    Args:
        word_embedding_dimension: Dimensions for the word embeddings
        pooling_mode: Either "cls", "lasttoken", "max", "mean",
            "mean_sqrt_len_tokens", or "weightedmean". If set,
            overwrites the other pooling_mode_* settings
        pooling_mode_cls_token: Use the first token (CLS token) as text
            representations
        pooling_mode_max_tokens: Use max in each dimension over all
            tokens.
        pooling_mode_mean_tokens: Perform mean-pooling
        pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but
            divide by sqrt(input_length).
        pooling_mode_weightedmean_tokens: Perform (position) weighted
            mean pooling. See `SGPT: GPT Sentence Embeddings for
            Semantic Search <https://arxiv.org/abs/2202.08904>`_.
        pooling_mode_lasttoken: Perform last token pooling. See `SGPT:
            GPT Sentence Embeddings for Semantic Search
            <https://arxiv.org/abs/2202.08904>`_ and `Text and Code
            Embeddings by Contrastive Pre-Training
            <https://arxiv.org/abs/2201.10005>`_.
        include_prompt: If False and the features contain a
            "prompt_length" entry, the first ``prompt_length`` positions
            are masked out before pooling.
    """

    POOLING_MODES = (
        "cls",
        "lasttoken",
        "max",
        "mean",
        "mean_sqrt_len_tokens",
        "weightedmean",
    )

    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode: str | None = None,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
        include_prompt: bool = True,
    ) -> None:
        super().__init__()

        # Attributes that are serialized by save() and fed back into __init__ by load().
        self.config_keys = [
            "word_embedding_dimension",
            "pooling_mode_cls_token",
            "pooling_mode_mean_tokens",
            "pooling_mode_max_tokens",
            "pooling_mode_mean_sqrt_len_tokens",
            "pooling_mode_weightedmean_tokens",
            "pooling_mode_lasttoken",
            "include_prompt",
        ]

        if pooling_mode is not None:  # Set pooling mode by string
            pooling_mode = pooling_mode.lower()

            if pooling_mode not in self.POOLING_MODES:
                raise ValueError(
                    f"Set invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}."
                )

            # A string mode selects exactly one strategy and overrides every boolean flag.
            pooling_mode_cls_token = pooling_mode == "cls"
            pooling_mode_max_tokens = pooling_mode == "max"
            pooling_mode_mean_tokens = pooling_mode == "mean"
            pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens"
            pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean"
            pooling_mode_lasttoken = pooling_mode == "lasttoken"

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken
        self.include_prompt = include_prompt

        # Each active mode appends one word_embedding_dimension-sized vector to the output.
        pooling_mode_multiplier = sum(
            [
                pooling_mode_cls_token,
                pooling_mode_max_tokens,
                pooling_mode_mean_tokens,
                pooling_mode_mean_sqrt_len_tokens,
                pooling_mode_weightedmean_tokens,
                pooling_mode_lasttoken,
            ]
        )
        self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension

    def __repr__(self) -> str:
        return f"Pooling({self.get_config_dict()})"

    def get_pooling_mode_str(self) -> str:
        """Returns the pooling mode as string (active modes joined by "+")."""
        modes = []
        if self.pooling_mode_cls_token:
            modes.append("cls")
        if self.pooling_mode_mean_tokens:
            modes.append("mean")
        if self.pooling_mode_max_tokens:
            modes.append("max")
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append("mean_sqrt_len_tokens")
        if self.pooling_mode_weightedmean_tokens:
            modes.append("weightedmean")
        if self.pooling_mode_lasttoken:
            modes.append("lasttoken")

        return "+".join(modes)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """
        Pools the token embeddings into one sentence embedding per batch item.

        Args:
            features: Must contain "token_embeddings" (bs, seq_len, hidden_dim). Optional entries that are
                honoured when present: "attention_mask" (bs, seq_len; defaults to all ones), "prompt_length",
                "cls_token_embeddings" and "token_weights_sum".

        Returns:
            The same ``features`` dict with the concatenated pooled vectors stored under "sentence_embedding".
            ``features["token_embeddings"]`` is left unmodified. When prompt tokens are excluded, the mask
            stored in ``features["attention_mask"]`` is updated in place.
        """
        token_embeddings = features["token_embeddings"]
        attention_mask = (
            features["attention_mask"]
            if "attention_mask" in features
            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
        )
        if not self.include_prompt and "prompt_length" in features:
            # Mask out the leading prompt positions (writes into the mask held by `features`).
            attention_mask[:, : features["prompt_length"]] = 0

        # Every strategy except CLS needs the mask broadcast to the embedding shape; build it once.
        needs_mask = any(
            (
                self.pooling_mode_max_tokens,
                self.pooling_mode_mean_tokens,
                self.pooling_mode_mean_sqrt_len_tokens,
                self.pooling_mode_weightedmean_tokens,
                self.pooling_mode_lasttoken,
            )
        )
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            if needs_mask
            else None
        )

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0])  # Take first token by default
            output_vectors.append(cls_token)

        if self.pooling_mode_max_tokens:
            # Work on a copy: overwriting the padded positions in place would corrupt
            # features["token_embeddings"] for the caller and for the pooling modes below.
            masked_embeddings = token_embeddings.clone()
            masked_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(masked_embeddings, 1)[0]
            output_vectors.append(max_over_time)

        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            # Guard against division by zero for fully masked rows.
            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))

        if self.pooling_mode_weightedmean_tokens:
            # token_embeddings shape: bs, seq, hidden_dim
            # Position weights 1..seq_len, created directly on the target device.
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1, device=token_embeddings.device)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .to(token_embeddings.dtype)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            weighted_mask = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * weighted_mask, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = weighted_mask.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)

        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
            # Use flip and max() to get the last index of 1 in the attention mask

            if torch.jit.is_tracing():
                # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime: https://github.com/microsoft/onnxruntime/issues/10068
                attention_mask = attention_mask.to(torch.int32)

            values, indices = attention_mask.flip(1).max(1)
            # Rows without any attended token fall back to gathering position 0.
            indices = torch.where(values == 0, seq_len - 1, indices)
            gather_indices = seq_len - indices - 1

            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
            # The gathered token normally has attn_mask = 1, so the mask is a no-op there; it is
            # still applied so that the fallback position of fully masked rows yields zeros.
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features["sentence_embedding"] = output_vector
        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Returns the size of the pooled output (number of active modes * word_embedding_dimension)."""
        return self.pooling_output_dimension

    def get_config_dict(self) -> dict[str, Any]:
        """Returns the constructor arguments listed in ``config_keys`` with their current values."""
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path) -> None:
        """Writes the configuration as ``config.json`` into the directory ``output_path``."""
        with open(os.path.join(output_path, "config.json"), "w", encoding="utf-8") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path) -> Pooling:
        """Builds a Pooling layer from the ``config.json`` found in the directory ``input_path``."""
        with open(os.path.join(input_path, "config.json"), encoding="utf-8") as fIn:
            config = json.load(fIn)

        return Pooling(**config)