ROCm 执行提供程序

ROCm 执行提供程序支持在启用 AMD ROCm 的 GPU 上进行硬件加速计算。

目录

安装

注意 请务必安装此处指定的正确版本的 Pytorch PyTorch 版本

对于 Nightly PyTorch 构建版本,请参阅 Pytorch 主页 并选择 ROCm 作为计算平台。

预构建的带有 ROCm EP 的 ONNX Runtime 二进制文件已发布,适用于大多数语言绑定。请参考 安装 ORT

要求

ONNX Runtime ROCm
主分支 6.0
1.17 6.0
5.7
1.16 5.6
5.5
5.4.2
1.15 5.4.2
5.4
5.3.2
1.14 5.4
5.3.2
1.13 5.4
5.3.2
1.12 5.2.3
5.2

构建

有关构建说明,请参阅构建页面

配置选项

ROCm 执行提供程序支持以下配置选项。

device_id

设备 ID。

默认值:0

tunable_op_enable

设置为使用 TunableOp。

默认值:false

tunable_op_tuning_enable

设置 TunableOp 尝试进行在线调优。

默认值:false

user_compute_stream

定义推理运行的计算流。它隐式设置 has_user_compute_stream 选项。它不能通过 UpdateROCMProviderOptions 设置。这不能与外部分配器结合使用。

Python 用法示例

providers = [("ROCMExecutionProvider", {"device_id": torch.cuda.current_device(),
                                        "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
sess_options = ort.SessionOptions()
sess = ort.InferenceSession("my_model.onnx", sess_options=sess_options, providers=providers)

为了利用用户计算流,建议使用 I/O 绑定 将输入和输出绑定到设备中的张量。

do_copy_in_default_stream

是否在默认流中进行复制或使用单独的流。 推荐设置为 true。 如果为 false,则可能存在竞争条件,并且可能获得更好的性能。

默认值:true

gpu_mem_limit

设备内存 arena 的大小限制(以字节为单位)。此大小限制仅适用于执行提供程序的 arena。总设备内存使用量可能更高。 s:C++ size_t 类型的最大值(实际上是无限的)

注意: 将被 default_memory_arena_cfg 的内容覆盖(如果指定)

arena_extend_strategy

扩展设备内存 arena 的策略。

描述
kNextPowerOfTwo (0) 后续扩展以更大的量扩展(乘以 2 的幂)
kSameAsRequested (1) 按请求量扩展

默认值:kNextPowerOfTwo

注意: 将被 default_memory_arena_cfg 的内容覆盖(如果指定)

gpu_external_[alloc|free|empty_cache]

gpu_external_* 用于传递外部分配器。 Python 用法示例

from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator

provider_option_map["gpu_external_alloc"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address())
provider_option_map["gpu_external_free"] = str(torch_gpu_allocator.gpu_caching_allocator_raw_delete_address())
provider_option_map["gpu_external_empty_cache"] = str(torch_gpu_allocator.gpu_caching_allocator_empty_cache_address())

默认值:0

用法

C/C++

Ort::Env env = Ort::Env{ORT_LOGGING_LEVEL_ERROR, "Default"};
Ort::SessionOptions so;
int device_id = 0;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCm(so, device_id));

C API 详细信息在此

Python

Python API 详细信息在此

示例

Python

import onnxruntime as ort

model_path = '<path to model>'

providers = [
    'ROCMExecutionProvider',
    'CPUExecutionProvider',
]

session = ort.InferenceSession(model_path, providers=providers)