Source code for sentence_transformers.models.Pooling

from __future__ import annotations

import json
import os
from typing import Any

import torch
from torch import Tensor, nn


class Pooling(nn.Module):
    """Performs pooling (max or mean) on the token embeddings.

    Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
    This layer also allows to use the CLS token if it is returned by the underlying word embedding model.
    You can concatenate multiple poolings together.

    Args:
        word_embedding_dimension: Dimensions for the word embeddings
        pooling_mode: Either "cls", "lasttoken", "max", "mean", "mean_sqrt_len_tokens", or "weightedmean".
            If set, overwrites the other pooling_mode_* settings
        pooling_mode_cls_token: Use the first token (CLS token) as text representations
        pooling_mode_max_tokens: Use max in each dimension over all tokens.
        pooling_mode_mean_tokens: Perform mean-pooling
        pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length).
        pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling. See `SGPT: GPT Sentence
            Embeddings for Semantic Search <https://arxiv.org/abs/2202.08904>`_.
        pooling_mode_lasttoken: Perform last token pooling. See `SGPT: GPT Sentence Embeddings for Semantic
            Search <https://arxiv.org/abs/2202.08904>`_ and `Text and Code Embeddings by Contrastive
            Pre-Training <https://arxiv.org/abs/2201.10005>`_.
    """

    POOLING_MODES = (
        "cls",
        "lasttoken",
        "max",
        "mean",
        "mean_sqrt_len_tokens",
        "weightedmean",
    )

    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode: str = None,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
        include_prompt=True,
    ) -> None:
        super().__init__()

        self.config_keys = [
            "word_embedding_dimension",
            "pooling_mode_cls_token",
            "pooling_mode_mean_tokens",
            "pooling_mode_max_tokens",
            "pooling_mode_mean_sqrt_len_tokens",
            "pooling_mode_weightedmean_tokens",
            "pooling_mode_lasttoken",
            "include_prompt",
        ]

        if pooling_mode is not None:  # Set pooling mode by string
            pooling_mode = pooling_mode.lower()

            if pooling_mode not in self.POOLING_MODES:
                raise ValueError(
                    f"Set invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}."
                )

            pooling_mode_cls_token = pooling_mode == "cls"
            pooling_mode_max_tokens = pooling_mode == "max"
            pooling_mode_mean_tokens = pooling_mode == "mean"
            pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens"
            pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean"
            pooling_mode_lasttoken = pooling_mode == "lasttoken"

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken

        self.include_prompt = include_prompt

        pooling_mode_multiplier = sum(
            [
                pooling_mode_cls_token,
                pooling_mode_max_tokens,
                pooling_mode_mean_tokens,
                pooling_mode_mean_sqrt_len_tokens,
                pooling_mode_weightedmean_tokens,
                pooling_mode_lasttoken,
            ]
        )
        self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension

    def __repr__(self) -> str:
        return f"Pooling({self.get_config_dict()})"

    def get_pooling_mode_str(self) -> str:
        """Returns the pooling mode as string"""
        modes = []
        if self.pooling_mode_cls_token:
            modes.append("cls")
        if self.pooling_mode_mean_tokens:
            modes.append("mean")
        if self.pooling_mode_max_tokens:
            modes.append("max")
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append("mean_sqrt_len_tokens")
        if self.pooling_mode_weightedmean_tokens:
            modes.append("weightedmean")
        if self.pooling_mode_lasttoken:
            modes.append("lasttoken")

        return "+".join(modes)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = (
            features["attention_mask"]
            if "attention_mask" in features
            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
        )
        if not self.include_prompt and "prompt_length" in features:
            attention_mask[:, : features["prompt_length"]] = 0

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0])  # Take first token by default
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            # token_embeddings shape: bs, seq, hidden_dim
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .to(token_embeddings.dtype)
                .to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
            # Use flip and max() to get the last index of 1 in the attention mask

            if torch.jit.is_tracing():
                # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime:
                # https://github.com/microsoft/onnxruntime/issues/10068
                attention_mask = attention_mask.to(torch.int32)

            values, indices = attention_mask.flip(1).max(1)
            indices = torch.where(values == 0, seq_len - 1, indices)
            gather_indices = seq_len - indices - 1

            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
            # Actually no need for the attention mask as we gather the last token where attn_mask = 1
            # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we
            # use the attention mask to ignore them again
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features["sentence_embedding"] = output_vector
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.pooling_output_dimension

    def get_config_dict(self) -> dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path) -> None:
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path) -> Pooling:
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        return Pooling(**config)
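
The sketch below shows the two typical ways this layer is exercised: calling it directly on a features dict that already contains "token_embeddings" and "attention_mask", and composing it with a models.Transformer module inside a SentenceTransformer. It is a minimal, illustrative example rather than part of this module; the model name "bert-base-uncased" and the tensor shapes are assumptions chosen only for demonstration.

# Usage sketch (illustrative; assumes torch and sentence-transformers are installed).
import torch

from sentence_transformers import SentenceTransformer, models
from sentence_transformers.models import Pooling

# 1) Standalone: mean-pool a batch of 2 sequences of 12 token embeddings (dim 768).
pooling = Pooling(word_embedding_dimension=768, pooling_mode="mean")
features = {
    "token_embeddings": torch.randn(2, 12, 768),
    "attention_mask": torch.ones(2, 12, dtype=torch.int64),
}
out = pooling(features)
print(out["sentence_embedding"].shape)  # torch.Size([2, 768])
print(pooling.get_pooling_mode_str())   # "mean"

# 2) Composed with a transformer module inside a SentenceTransformer:
#    Transformer produces token embeddings, Pooling reduces them to one vector per text.
word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=256)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
embeddings = model.encode(["This is an example sentence."])
print(embeddings.shape)  # (1, 768) for a 768-dimensional base model

Because get_sentence_embedding_dimension() returns pooling_mode_multiplier * word_embedding_dimension, enabling several pooling modes at once (e.g. mean and max) concatenates their outputs and multiplies the embedding size accordingly.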