教程:在线强化学习用于多跳研究¶
警告:此功能是全新的且极具实验性。与DSPY中几乎所有其他功能不同,它目前处于纯粹的概念验证和开发模式,但我们发布它是为了鼓励社区参与。
对于本教程,您还需要DSPy的Arbor RL服务器。
> pip install arbor-ai
> python -m arbor.cli serve --arbor-config arbor.yaml
在您的目录中创建 arbor.yaml,包含类似这样的计划:
inference:
gpu_ids: '0'
training:
gpu_ids: '1, 2'
将GPU 0分配给推理,GPU 1和2分配给训练。
在 [ ]:
Copied!
import dspy
from dspy.clients.lm_local_arbor import ArborProvider
port = 7453
local_lm_name = "Qwen/Qwen2.5-7B-Instruct"
local_lm = dspy.LM(
model=f"openai/arbor:{local_lm_name}",
provider=ArborProvider(),
temperature=0.7,
api_base=f"http://localhost:{port}/v1/",
api_key="arbor",
)
dspy.configure(lm=local_lm)
openai_lm = dspy.LM(model="openai/gpt-4.1-mini")
导入 dspy
从 dspy.clients.lm_local_arbor 导入 ArborProvider
端口 = 7453
本地语言模型名称 = "Qwen/Qwen2.5-7B-Instruct"
本地语言模型 = dspy.LM(
模型=f"openai/arbor:{local_lm_name}",
提供者=ArborProvider(),
温度=0.7,
API基础地址=f"http://localhost:{port}/v1/",
API密钥="arbor",
)
dspy.configure(lm=local_lm)
openai语言模型 = dspy.LM(模型="openai/gpt-4.1-mini")
安装依赖并下载数据¶
为了进行检索,我们将使用很酷的BM25S库,因为它非常轻量级。你可以用任何你喜欢的组件来替换这个。
> pip install -U bm25s PyStemmer "jax[cpu]"
接下来,我们将下载截至2017年所有5,000,000个维基百科页面的摘要快照(即首段内容)。我们将以此作为我们的检索语料库。
这是500MB压缩文件,因此下载和解压缩可能需要2-3分钟。
from dspy.utils import download
download("https://huggingface.co/dspy/cache/resolve/main/wiki.abstracts.2017.tar.gz")
!tar -xzvf wiki.abstracts.2017.tar.gz
然后让我们为BM25检索建立索引!这需要2-3分钟。
在 [ ]:
Copied!
import ujson
import bm25s
import Stemmer
corpus = []
with open("wiki.abstracts.2017.jsonl") as f:
for line in f:
line = ujson.loads(line)
corpus.append(f"{line['title']} | {' '.join(line['text'])}")
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25(k1=0.9, b=0.4)
retriever.index(corpus_tokens)
导入 ujson
导入 bm25s
导入 Stemmer
语料库 = []
使用 open("wiki.abstracts.2017.jsonl") 作为 f:
对于 f 中的每一行:
行 = ujson.loads(行)
语料库.append(f"{行['title']} | {' '.join(行['text'])}")
词干提取器 = Stemmer.Stemmer("english")
语料库标记 = bm25s.tokenize(语料库, stopwords="en", stemmer=词干提取器)
检索器 = bm25s.BM25(k1=0.9, b=0.4)
检索器.index(语料库标记)
加载HoVer数据集。¶
让我们为任务加载一个数据集。我们将从HoVer多跳任务中加载示例,其中输入是一个(确实!)复杂的声明,而我们寻求的输出是事实核查该声明所需的维基百科页面集合。
在 [ ]:
Copied!
import random
from dspy.datasets import DataLoader
kwargs = dict(fields=("claim", "supporting_facts", "hpqa_id", "num_hops"), input_keys=("claim",))
hover = DataLoader().from_huggingface(dataset_name="hover-nlp/hover", split="train", trust_remote_code=True, **kwargs)
hpqa_ids = set()
hover = [
dspy.Example(claim=x.claim, titles=list(set([y["key"] for y in x.supporting_facts]))).with_inputs("claim")
for x in hover
if x["num_hops"] == 3 and x["hpqa_id"] not in hpqa_ids and not hpqa_ids.add(x["hpqa_id"])
]
random.Random(0).shuffle(hover)
trainset, devset, testset = hover[:600], hover[600:900], hover[900:]
len(trainset), len(devset), len(testset)
import random
from dspy.datasets import DataLoader
kwargs = dict(fields=("claim", "supporting_facts", "hpqa_id", "num_hops"), input_keys=("claim",))
hover = DataLoader().from_huggingface(dataset_name="hover-nlp/hover", split="train", trust_remote_code=True, **kwargs)
hpqa_ids = set()
hover = [
dspy.Example(claim=x.claim, titles=list(set([y["key"] for y in x.supporting_facts]))).with_inputs("claim")
for x in hover
if x["num_hops"] == 3 and x["hpqa_id"] not in hpqa_ids and not hpqa_ids.add(x["hpqa_id"])
]
random.Random(0).shuffle(hover)
trainset, devset, testset = hover[:600], hover[600:900], hover[900:]
len(trainset), len(devset), len(testset)
现在,让我们定义一个函数来在维基百科中进行搜索。这将使用我们的BM25索引。
在 [ ]:
Copied!
def search(query: str, k: int) -> list[str]:
tokens = bm25s.tokenize(query, stopwords="en", stemmer=stemmer, show_progress=False)
results, scores = retriever.retrieve(tokens, k=k, n_threads=1, show_progress=False)
run = {corpus[doc]: float(score) for doc, score in zip(results[0], scores[0])}
return list(run.keys())
def search(query: str, k: int) -> list[str]:
tokens = bm25s.tokenize(query, stopwords="en", stemmer=stemmer, show_progress=False)
results, scores = retriever.retrieve(tokens, k=k, n_threads=1, show_progress=False)
run = {corpus[doc]: float(score) for doc, score in zip(results[0], scores[0])}
return list(run.keys())
一个用于多跳研究的DSPY程序¶
现在,让我们在DSPy中定义多跳程序。它会非常简单,由generate_query和append_notes模块组成。我们会仔细定义指令,尽管通常这不是必需的。
在 [ ]:
Copied!
instr1 = """
Given a claim and some key facts, generate a follow-up search query to find the next most essential clue towards verifying or refuting the claim. The goal ultimately is to find all documents implicated by the claim.
""".strip()
instr2 = """
Given a claim, some key facts, and new search results, identify any new learnings from the new search results, which will extend the key facts known so far about the whether the claim is true or false. The goal is to ultimately collect all facts that would help us find all documents implicated by the claim.
"""
class ResearchHop(dspy.Module):
def __init__(self, num_docs, num_hops):
self.num_docs, self.num_hops = num_docs, num_hops
self.generate_query = dspy.ChainOfThought(dspy.Signature("claim, key_facts -> followup_search_query", instr1))
self.append_notes = dspy.ChainOfThought(dspy.Signature("claim, key_facts, new_search_results -> new_key_facts", instr2))
def forward(self, claim: str) -> list[str]:
key_facts = []
retrieved_docs = []
for hop_idx in range(self.num_hops):
query = self.generate_query(claim=claim, key_facts=key_facts).followup_search_query if hop_idx else claim
search_results = search(query, k=self.num_docs)
retrieved_docs.extend(search_results)
if hop_idx == self.num_hops - 1:
break
prediction = self.append_notes(claim=claim, key_facts=key_facts, new_search_results=search_results)
key_facts.append(prediction.new_key_facts)
return dspy.Prediction(key_facts=key_facts, retrieved_docs=retrieved_docs)
instr1 = """
给定一个声明和一些关键事实,生成一个后续搜索查询,以找到验证或反驳该声明的下一个最关键线索。最终目标是找到声明所涉及的所有文档。
""".strip()
instr2 = """
给定一个声明、一些关键事实和新的搜索结果,从新的搜索结果中识别出任何新的发现,这些发现将扩展目前已知的关于声明真伪的关键事实。最终目标是收集所有有助于我们找到声明所涉及的所有文档的事实。
"""
class ResearchHop(dspy.Module):
def __init__(self, num_docs, num_hops):
self.num_docs, self.num_hops = num_docs, num_hops
self.generate_query = dspy.ChainOfThought(dspy.Signature("claim, key_facts -> followup_search_query", instr1))
self.append_notes = dspy.ChainOfThought(dspy.Signature("claim, key_facts, new_search_results -> new_key_facts", instr2))
def forward(self, claim: str) -> list[str]:
key_facts = []
retrieved_docs = []
for hop_idx in range(self.num_hops):
query = self.generate_query(claim=claim, key_facts=key_facts).followup_search_query if hop_idx else claim
search_results = search(query, k=self.num_docs)
retrieved_docs.extend(search_results)
if hop_idx == self.num_hops - 1:
break
prediction = self.append_notes(claim=claim, key_facts=key_facts, new_search_results=search_results)
key_facts.append(prediction.new_key_facts)
return dspy.Prediction(key_facts=key_facts, retrieved_docs=retrieved_docs)
定义此任务中的成功指标¶
在 [ ]:
Copied!
def recall(example, pred, trace=None):
gold_titles = example.titles
retrieved_titles = [doc.split(" | ")[0] for doc in pred.retrieved_docs]
return sum(x in retrieved_titles for x in set(gold_titles)) / len(gold_titles)
evaluate = dspy.Evaluate(devset=devset, metric=recall, num_threads=16, display_progress=True, display_table=5)
def recall(example, pred, trace=None):
gold_titles = example.titles
retrieved_titles = [doc.split(" | ")[0] for doc in pred.retrieved_docs]
return sum(x in retrieved_titles for x in set(gold_titles)) / len(gold_titles)
evaluate = dspy.Evaluate(devset=devset, metric=recall, num_threads=16, display_progress=True, display_table=5)
使用dspy.GRPO优化ResearchHop系统¶
在 [ ]:
Copied!
from dspy.teleprompt.grpo import GRPO
program = ResearchHop(num_docs=4, num_hops=2)
program.set_lm(local_lm)
# NOTE: Training on 6 GPUs.
train_kwargs = {
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 4,
"temperature": 0.7,
"beta": 0.04,
"learning_rate": 2e-5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"bf16": True,
"lr_scheduler_type": "constant_with_warmup",
"max_prompt_length": None,
"max_completion_length": None,
"scale_rewards": True,
"max_grad_norm": 0.5,
"lora": True,
}
compiler = GRPO(
metric=recall,
multitask=True,
num_dspy_examples_per_grpo_step=6,
num_samples_per_input=8,
exclude_demos=True,
num_train_steps=500,
num_threads=24,
use_train_as_val=False,
num_steps_for_val=10,
train_kwargs=train_kwargs,
report_train_scores=False,
)
optimized_program = compiler.compile(
student=program,
trainset=trainset,
valset=devset,
)
from dspy.teleprompt.grpo import GRPO
程序 = ResearchHop(文档数量=4, 跳数=2)
程序.设置语言模型(本地语言模型)
# 注意:在6个GPU上进行训练。
训练参数 = {
"每设备训练批次大小": 2,
"梯度累积步数": 4,
"温度": 0.7,
"beta": 0.04,
"学习率": 2e-5,
"梯度检查点": True,
"梯度检查点参数": {"使用可重入": False},
"bf16": True,
"学习率调度器类型": "带热身的常数",
"最大提示长度": None,
"最大完成长度": None,
"缩放奖励": True,
"最大梯度范数": 0.5,
"lora": True,
}
编译器 = GRPO(
指标=召回率,
多任务=True,
每GRPO步的dspy示例数=6,
每输入样本数=8,
排除演示=True,
训练步数=500,
线程数=24,
使用训练集作为验证集=False,
验证步数=10,
训练参数=训练参数,
报告训练分数=False,
)
优化程序 = 编译器.编译(
学生=程序,
训练集=训练集,
验证集=开发集,
)
现在,你可以使用GRPO'ed程序。
在 [ ]:
Copied!
example = devset[0]
optimized_program(**example.inputs())
示例 = 开发集[0]
优化程序(**示例.输入())
在我们初步实验中,训练约18小时后,召回率(开发集)从61.8%提升至66.2%。从成本/质量角度看,这通常比运行提示优化器dspy.MIPROv2或dspy.SIMBA的效果要差,但对于小型语言模型的任意语言模型程序进行在线强化学习来说,这仍然是一个非常扎实的开端。