ROCm 执行提供程序
ROCm 执行提供程序能够在支持 AMD ROCm 的 GPU 上实现硬件加速计算。
目录
安装
注意 请务必安装此处指定的正确版本 PyTorch PyTorch 版本。
对于 Nightly PyTorch 构建,请参阅 PyTorch 主页 并选择 ROCm 作为计算平台。
大多数语言绑定都提供了带有 ROCm EP 的 ONNX Runtime 预构建二进制文件。请参考 安装 ORT。
要求
ONNX Runtime | ROCm |
---|---|
main | 6.0 |
1.17 | 6.0 5.7 |
1.16 | 5.6 5.5 5.4.2 |
1.15 | 5.4.2 5.4 5.3.2 |
1.14 | 5.4 5.3.2 |
1.13 | 5.4 5.3.2 |
1.12 | 5.2.3 5.2 |
构建
有关构建说明,请参阅 构建页面。
配置选项
ROCm 执行提供程序支持以下配置选项。
device_id
设备 ID。
默认值: 0
tunable_op_enable
设置为使用 TunableOp。
默认值: false
tunable_op_tuning_enable
设置 TunableOp 尝试进行在线调优。
默认值: false
user_compute_stream
定义用于推理运行的计算流。它隐式设置了 has_user_compute_stream
选项。无法通过 UpdateROCMProviderOptions
进行设置。不能与外部分配器结合使用。
Python 用法示例
providers = [("ROCMExecutionProvider", {"device_id": torch.cuda.current_device(),
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
sess_options = ort.SessionOptions()
sess = ort.InferenceSession("my_model.onnx", sess_options=sess_options, providers=providers)
为了利用用户计算流,建议使用 I/O 绑定 将输入和输出绑定到设备中的张量。
do_copy_in_default_stream
是否在默认流中进行复制或使用单独的流。建议设置为 true。如果设置为 false,可能存在竞态条件,但也可能获得更好的性能。
默认值: true
gpu_mem_limit
设备内存竞技场(arena)的大小限制,单位为字节。此大小限制仅适用于执行提供程序的竞技场。总设备内存使用量可能更高。s: C++ size_t 类型的最大值(实际上是无限制的)
注意: 将被 default_memory_arena_cfg
的内容覆盖(如果指定)
arena_extend_strategy
扩展设备内存竞技场的策略。
值 | 描述 |
---|---|
kNextPowerOfTwo (0) | 后续扩展以更大的量进行(乘以 2 的幂) |
kSameAsRequested (1) | 按请求的量扩展 |
默认值: kNextPowerOfTwo
注意: 将被 default_memory_arena_cfg
的内容覆盖(如果指定)
gpu_external_[alloc|free|empty_cache]
gpu_external_* 用于传递外部分配器。Python 用法示例
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator
provider_option_map["gpu_external_alloc"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address())
provider_option_map["gpu_external_free"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_delete_address())
provider_option_map["gpu_external_empty_cache"] = str(torch_gpu_allocator.gpu_caching_allocator_empty_cache_address())
默认值: 0
用法
C/C++
Ort::Env env = Ort::Env{ORT_LOGGING_LEVEL_ERROR, "Default"};
Ort::SessionOptions so;
int device_id = 0;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCm(so, device_id));
C API 详情请参阅 此处。
Python
Python API 详情请参阅 此处。
示例
Python
import onnxruntime as ort
model_path = '<path to model>'
providers = [
'ROCMExecutionProvider',
'CPUExecutionProvider',
]
session = ort.InferenceSession(model_path, providers=providers)