PyTorch 自定义操作符
创建于:2024年6月18日 | 最后更新:2024年7月22日 | 最后验证:2024年11月5日
PyTorch 提供了一个庞大的运算符库,这些运算符作用于张量(例如 torch.add
,
torch.sum
等)。然而,您可能希望将一个新的自定义操作引入 PyTorch,
并使其与 torch.compile
、autograd 和 torch.vmap
等子系统一起工作。
为此,您必须通过 Python 的 torch.library 文档 或 C++ 的 TORCH_LIBRARY
API 向 PyTorch 注册自定义操作。
从Python编写自定义操作符
请参阅自定义Python操作符。
如果您希望从Python(而不是C++)编写自定义操作符,您可能希望:
你有一个Python函数,你希望PyTorch将其视为一个不透明的可调用对象,特别是在
torch.compile
和torch.export
方面。你有一些Python绑定到C++/CUDA内核,并希望这些能与PyTorch子系统(如
torch.compile
或torch.autograd
)组合使用
将自定义C++和/或CUDA代码与PyTorch集成
请参阅自定义C++和CUDA运算符。
如果您希望从C++(而不是Python)编写自定义操作符,您可能希望:
你有自定义的C++和/或CUDA代码。
你计划使用此代码与
AOTInductor
进行无Python推理。
自定义操作符手册
有关教程和本页面未涵盖的信息,请参阅 The Custom Operators Manual (我们正在努力将信息迁移到我们的文档站点)。我们建议您 首先阅读上述教程之一,然后将自定义操作符手册作为参考; 它并不适合从头到尾阅读。