# Source code for sentence_transformers.models.Normalize

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor, nn


class Normalize(nn.Module):
    """Module that rescales sentence embeddings to unit L2 norm.

    The layer defines no parameters or configuration of its own, so ``save``
    writes nothing and ``load`` simply builds a new instance.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """L2-normalize ``features["sentence_embedding"]`` along dimension 1.

        Args:
            features: Mapping that must contain a ``"sentence_embedding"``
                tensor.

        Returns:
            The same ``features`` object, modified in place so that
            ``"sentence_embedding"`` holds the normalized tensor.
        """
        embeddings = features["sentence_embedding"]
        features["sentence_embedding"] = F.normalize(embeddings, p=2, dim=1)
        return features

    def save(self, output_path) -> None:
        """Do nothing; ``output_path`` is accepted but never written to."""
        return None

    @staticmethod
    def load(input_path) -> Normalize:
        """Return a new ``Normalize``; ``input_path`` is not read."""
        return Normalize()