dspy.Refine
dspy.Refine(module: Module, N: int, reward_fn: Callable[[dict, Prediction], float], threshold: float, fail_count: int | None = None)
基础类: Module
通过使用不同的rollout ID在temperature=1.0下运行模块最多N次来优化模块,并返回最佳预测结果。
本模块通过使用不同的展开标识符多次运行提供的模块,并选择第一个超过指定阈值的预测或具有最高奖励的预测。 如果没有预测满足阈值,它会自动生成反馈以改进未来的预测。
参数:
| 名称 | 类型 | 描述 | 默认值 |
|---|---|---|---|
module
|
Module
|
要优化的模块。 |
必填 |
N
|
int
|
模块运行的次数。必须 |
必填 |
reward_fn
|
Callable
|
奖励函数。 |
必填 |
threshold
|
float
|
奖励函数的阈值。 |
必填 |
fail_count
|
Optional[int]
|
模块在抛出错误前可以失败的次数 |
None
|
示例
import dspy
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
# Define a QA module with chain of thought
qa = dspy.ChainOfThought("question -> answer")
# Define a reward function that checks for one-word answers
def one_word_answer(args, pred):
return 1.0 if len(pred.answer.split()) == 1 else 0.0
# Create a refined module that tries up to 3 times
best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)
# Use the refined module
result = best_of_3(question="What is the capital of Belgium?").answer
# Returns: Brussels
Source code in dspy/predict/refine.py
函数
__call__(*args, **kwargs) -> Prediction
Source code in dspy/primitives/module.py
acall(*args, **kwargs) -> Prediction
async
Source code in dspy/primitives/module.py
batch(examples: list[Example], num_threads: int | None = None, max_errors: int | None = None, return_failed_examples: bool = False, provide_traceback: bool | None = None, disable_progress_bar: bool = False) -> list[Example] | tuple[list[Example], list[Example], list[Exception]]
使用Parallel模块并行处理dspy.Example实例列表。
参数:
| 名称 | 类型 | 描述 | 默认值 |
|---|---|---|---|
examples
|
list[Example]
|
要处理的dspy.Example实例列表。 |
必填 |
num_threads
|
int | None
|
用于并行处理的线程数量。 |
None
|
max_errors
|
int | None
|
Maximum number of errors allowed before stopping execution.
If |
None
|
return_failed_examples
|
bool
|
是否返回失败的示例和异常。 |
False
|
provide_traceback
|
bool | None
|
是否在错误日志中包含回溯信息。 |
None
|
disable_progress_bar
|
bool
|
是否显示进度条。 |
False
|
返回:
| 类型 | 描述 |
|---|---|
list[Example] | tuple[list[Example], list[Example], list[Exception]]
|
结果列表,以及可选的失败示例和异常。 |
Source code in dspy/primitives/module.py
deepcopy()
深拷贝模块。
这是对默认Python深拷贝的一个调整,仅对self.parameters()进行深拷贝,而对于其他属性,我们只进行浅拷贝。
Source code in dspy/primitives/base_module.py
dump_state(json_mode=True)
forward(**kwargs)
Source code in dspy/predict/refine.py
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | |
get_lm()
Source code in dspy/primitives/module.py
inspect_history(n: int = 1)
load(path)
加载已保存的模块。如果您想加载整个程序,而不仅仅是现有程序的状态,您可能还想查看 dspy.load。
参数:
| 名称 | 类型 | 描述 | 默认值 |
|---|---|---|---|
path
|
str
|
保存状态文件的路径,应为 .json 或 .pkl 文件 |
必填 |
Source code in dspy/primitives/base_module.py
load_state(state)
map_named_predictors(func)
named_parameters()
与PyTorch不同,它也能处理(非递归的)参数列表。
Source code in dspy/primitives/base_module.py
named_predictors()
named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]
查找模块中的所有子模块及其名称。
假设 self.children[4]['key'].sub_module 是一个子模块。那么名称将是
children[4]['key'].sub_module。但如果该子模块可以通过不同路径访问,则只会返回其中一个路径。
Source code in dspy/primitives/base_module.py
parameters()
predictors()
reset_copy()
save(path, save_program=False, modules_to_serialize=None)
保存模块。
将模块保存到目录或文件。有两种模式:
- save_program=False: 仅将模块的状态保存为json或pickle文件,具体取决于文件扩展名的值。
- save_program=True: 通过cloudpickle将整个模块保存到目录,其中包含模型的状态和架构。
如果 save_program=True 并且提供了 modules_to_serialize,它将使用 cloudpickle 的 register_pickle_by_value 注册这些模块进行序列化。
这会使 cloudpickle 按值而不是按引用序列化模块,确保模块与保存的程序一起完全保留。当您有需要与程序一起序列化的自定义模块时,这非常有用。
如果为 None,则不会注册任何模块进行序列化。
我们还会保存依赖版本,以便加载的模型可以检查是否存在关键依赖或dspy版本的版本不匹配问题。
参数:
| 名称 | 类型 | 描述 | 默认值 |
|---|---|---|---|
path
|
str
|
保存状态文件的路径,当 |
必填 |
save_program
|
bool
|
如果为True,则通过cloudpickle将整个模块保存到目录,否则仅保存状态。 |
False
|
modules_to_serialize
|
list
|
一个模块列表,用于通过 cloudpickle 的 |
None
|
Source code in dspy/primitives/base_module.py
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | |
:::