采样器¶
批量采样器¶
- class sentence_transformers.training_args.BatchSamplers(value)[源代码][源代码]¶
存储批量采样器可接受的字符串标识符。
批量采样器负责在训练期间确定样本如何分组为批次。有效选项包括:
BatchSamplers.BATCH_SAMPLER
: [默认] 使用 :class:~sentence_transformers.sampler.DefaultBatchSampler
,这是默认的 PyTorch 批采样器。BatchSamplers.NO_DUPLICATES
: 使用 :class:~sentence_transformers.sampler.NoDuplicatesBatchSampler
,确保批次中没有重复样本。推荐用于使用批次内负样本的损失函数,例如:BatchSamplers.GROUP_BY_LABEL
: 使用 :class:~sentence_transformers.sampler.GroupByLabelBatchSampler
,确保每个批次包含来自同一标签的2个以上样本。推荐用于需要同一标签多个样本的损失函数,例如:
如果你想使用自定义的批量采样器,你可以创建一个新的 Trainer 类,该类继承自 :class:
~sentence_transformers.trainer.SentenceTransformerTrainer
并重写 :meth:~sentence_transformers.trainer.SentenceTransformerTrainer.get_batch_sampler
方法。该方法必须返回一个支持__iter__
和__len__
方法的类实例。前者应为每个批次生成一个索引列表,后者应返回批次的数量。- 用法:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.training_args import BatchSamplers from sentence_transformers.losses import MultipleNegativesRankingLoss from datasets import Dataset model = SentenceTransformer("microsoft/mpnet-base") train_dataset = Dataset.from_dict({ "anchor": ["It's nice weather outside today.", "He drove to work."], "positive": ["It's so sunny.", "He took the car to the office."], }) loss = MultipleNegativesRankingLoss(model) args = SentenceTransformerTrainingArguments( output_dir="checkpoints", batch_sampler=BatchSamplers.NO_DUPLICATES, ) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, loss=loss, ) trainer.train()
- class sentence_transformers.sampler.DefaultBatchSampler(*args, **kwargs)[源代码][源代码]¶
此采样器是 SentenceTransformer 库中使用的默认批次采样器。它等同于 PyTorch 的 BatchSampler。
- 参数:
sampler (Sampler or Iterable) -- 用于从数据集中采样元素的采样器,例如 SubsetRandomSampler。
batch_size (int) -- 每个批次的样本数量。
drop_last (bool) -- 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。
- class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = [], generator: torch.Generator = None, seed: int = 0)[源代码][源代码]¶
此采样器创建批次,使得每个批次中的样本在各列中的值都是唯一的。这在损失函数考虑批次中的其他样本作为批次内负样本时非常有用,并且您希望确保负样本不是锚点/正样本的重复。
- 推荐用于:
- 参数:
dataset (Dataset) -- 要从中采样的数据集。
batch_size (int) -- 每个批次的样本数量。
drop_last (bool) -- 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。
valid_label_columns (List[str]) -- 要检查标签的列名列表。数据集中找到的第一个
valid_label_columns
中的列名将被用作标签列。generator (torch.Generator, optional) -- 用于打乱索引的可选随机数生成器。
seed (int, optional) -- 随机数生成器的种子,以确保可重复性。
- class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = None, generator: torch.Generator = None, seed: int = 0)[源代码][源代码]¶
此采样器按样本的标签对其进行分组,并旨在创建批次,使得每个批次中的样本标签尽可能同质化。此采样器旨在与
Batch...TripletLoss
类一起使用,这些类要求每个批次至少包含每个标签类的2个示例。- 推荐用于:
- 参数:
dataset (Dataset) -- 要从中采样的数据集。
batch_size (int) -- 每个批次的样本数量。必须是2的倍数。
drop_last (bool) -- 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。
valid_label_columns (List[str]) -- 要检查标签的列名列表。数据集中找到的第一个
valid_label_columns
中的列名将被用作标签列。generator (torch.Generator, optional) -- 用于打乱索引的可选随机数生成器。
seed (int, optional) -- 随机数生成器的种子,以确保可重复性。
MultiDatasetBatchSamplers¶
- class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[源代码][源代码]¶
存储多数据集批采样器可接受的字符串标识符。
多数据集批采样器负责在训练期间确定从多个数据集中按什么顺序采样批次。有效选项包括:
MultiDatasetBatchSamplers.ROUND_ROBIN
: 使用 :class:~sentence_transformers.sampler.RoundRobinBatchSampler
,该采样器从每个数据集中按轮询方式采样,直到其中一个数据集耗尽。使用这种策略,可能不会使用每个数据集中的所有样本,但每个数据集都被均匀采样。MultiDatasetBatchSamplers.PROPORTIONAL
: [默认] 使用 :class:~sentence_transformers.sampler.ProportionalBatchSampler
,该采样器按数据集的大小比例从每个数据集中采样。使用此策略,每个数据集的所有样本都会被使用,并且较大的数据集会被更频繁地采样。
- 用法:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.training_args import MultiDatasetBatchSamplers from sentence_transformers.losses import CoSENTLoss from datasets import Dataset, DatasetDict model = SentenceTransformer("microsoft/mpnet-base") train_general = Dataset.from_dict({ "sentence_A": ["It's nice weather outside today.", "He drove to work."], "sentence_B": ["It's so sunny.", "He took the car to the bank."], "score": [0.9, 0.4], }) train_medical = Dataset.from_dict({ "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."], "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."], "score": [0.8, 0.6, 0.7], }) train_legal = Dataset.from_dict({ "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."], "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."], "score": [0.7, 0.8], }) train_dataset = DatasetDict({ "general": train_general, "medical": train_medical, "legal": train_legal, }) loss = CoSENTLoss(model) args = SentenceTransformerTrainingArguments( output_dir="checkpoints", multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL, ) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, loss=loss, ) trainer.train()
- class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int | None = None)[源代码][源代码]¶
批量采样器,以轮询方式从多个批量采样器中生成批次,直到其中一个耗尽。使用此采样器,不太可能使用每个数据集中的所有样本,但我们确实确保每个数据集被均匀采样。
- 参数:
dataset (ConcatDataset) -- 多个数据集的连接。
batch_samplers (List[BatchSampler]) -- 批量采样器列表,每个数据集对应 ConcatDataset 中的一个。
generator (torch.Generator, optional) -- 一个用于可重复采样的生成器。默认为 None。
seed (int, optional) -- 生成器的种子。默认为 None。
- class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator, seed: int)[源代码][源代码]¶
批量采样器,按数据集大小比例从每个数据集中采样,直到所有数据集同时耗尽。使用此采样器,每个数据集的所有样本都会被使用,并且较大的数据集会被更频繁地采样。
- 参数:
dataset (ConcatDataset) -- 多个数据集的连接。
batch_samplers (List[BatchSampler]) -- 批量采样器列表,每个数据集对应 ConcatDataset 中的一个。
generator (torch.Generator, optional) -- 一个用于可重复采样的生成器。默认为 None。
seed (int, optional) -- 生成器的种子。默认为 None。