ROCm 执行提供程序
ROCm 执行提供程序支持在启用 AMD ROCm 的 GPU 上进行硬件加速计算。
目录
安装
注意 请务必安装此处指定的正确版本的 Pytorch PyTorch 版本。
对于 Nightly PyTorch 构建版本,请参阅 Pytorch 主页 并选择 ROCm 作为计算平台。
预构建的带有 ROCm EP 的 ONNX Runtime 二进制文件已发布,适用于大多数语言绑定。请参考 安装 ORT。
要求
ONNX Runtime | ROCm |
---|---|
主分支 | 6.0 |
1.17 | 6.0 5.7 |
1.16 | 5.6 5.5 5.4.2 |
1.15 | 5.4.2 5.4 5.3.2 |
1.14 | 5.4 5.3.2 |
1.13 | 5.4 5.3.2 |
1.12 | 5.2.3 5.2 |
构建
有关构建说明,请参阅构建页面。
配置选项
ROCm 执行提供程序支持以下配置选项。
device_id
设备 ID。
默认值:0
tunable_op_enable
设置为使用 TunableOp。
默认值:false
tunable_op_tuning_enable
设置 TunableOp 尝试进行在线调优。
默认值:false
user_compute_stream
定义推理运行的计算流。它隐式设置 has_user_compute_stream
选项。它不能通过 UpdateROCMProviderOptions
设置。这不能与外部分配器结合使用。
Python 用法示例
providers = [("ROCMExecutionProvider", {"device_id": torch.cuda.current_device(),
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
sess_options = ort.SessionOptions()
sess = ort.InferenceSession("my_model.onnx", sess_options=sess_options, providers=providers)
为了利用用户计算流,建议使用 I/O 绑定 将输入和输出绑定到设备中的张量。
do_copy_in_default_stream
是否在默认流中进行复制或使用单独的流。 推荐设置为 true。 如果为 false,则可能存在竞争条件,并且可能获得更好的性能。
默认值:true
gpu_mem_limit
设备内存 arena 的大小限制(以字节为单位)。此大小限制仅适用于执行提供程序的 arena。总设备内存使用量可能更高。 s:C++ size_t 类型的最大值(实际上是无限的)
注意: 将被 default_memory_arena_cfg
的内容覆盖(如果指定)
arena_extend_strategy
扩展设备内存 arena 的策略。
值 | 描述 |
---|---|
kNextPowerOfTwo (0) | 后续扩展以更大的量扩展(乘以 2 的幂) |
kSameAsRequested (1) | 按请求量扩展 |
默认值:kNextPowerOfTwo
注意: 将被 default_memory_arena_cfg
的内容覆盖(如果指定)
gpu_external_[alloc|free|empty_cache]
gpu_external_* 用于传递外部分配器。 Python 用法示例
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator
provider_option_map["gpu_external_alloc"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address())
provider_option_map["gpu_external_free"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_delete_address())
provider_option_map["gpu_external_empty_cache"] = str(torch_gpu_allocator.gpu_caching_allocator_empty_cache_address())
默认值:0
用法
C/C++
Ort::Env env = Ort::Env{ORT_LOGGING_LEVEL_ERROR, "Default"};
Ort::SessionOptions so;
int device_id = 0;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCm(so, device_id));
C API 详细信息在此。
Python
Python API 详细信息在此。
示例
Python
import onnxruntime as ort
model_path = '<path to model>'
providers = [
'ROCMExecutionProvider',
'CPUExecutionProvider',
]
session = ort.InferenceSession(model_path, providers=providers)