创建 Float16 和混合精度模型

将模型转换为使用 float16 而不是 float32 可以减小模型大小(最多一半)并提高某些 GPU 的性能。可能会有一些精度损失,但在许多模型中,新的精度是可以接受的。float16 转换不需要调整数据,这使其比量化更可取。

目录

Float16 转换

按照以下步骤将模型转换为 float16

  1. 安装 onnx 和 onnxconverter-common

    pip install onnx onnxconverter-common

  2. 在 python 中使用 convert_float_to_float16 函数。

     import onnx
     from onnxconverter_common import float16
    
     model = onnx.load("path/to/model.onnx")
     model_fp16 = float16.convert_float_to_float16(model)
     onnx.save(model_fp16, "path/to/model_fp16.onnx")
    

Float16 工具参数

如果转换后的模型无法工作或精度较差,您可能需要设置其他参数。

convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False,
                         disable_shape_infer=False, op_block_list=None, node_block_list=None)
  • model:要转换的 ONNX 模型。
  • min_positive_val, max_finite_val:常量值将被裁剪到这些边界。0.0naninf-inf 将保持不变。
  • keep_io_types:模型输入/输出是否应保留为 float32。
  • disable_shape_infer:跳过运行 onnx 形状/类型推断。如果形状推断崩溃、模型中已存在形状/类型或不需要类型(类型用于确定不支持/阻止的操作需要在哪里进行强制转换操作)时很有用。
  • op_block_list:保留为 float32 的操作类型列表。默认情况下使用 float16.DEFAULT_OP_BLOCK_LIST 中的列表。此列表包含 ONNX Runtime 中不支持 float16 的操作。
  • node_block_list:保留为 float32 的节点名称列表。

注意:阻止的操作将在其周围插入 float16/float32 之间的强制转换。目前,如果两个阻止的操作彼此相邻,则仍将插入强制转换,从而创建冗余对。ORT 将在运行时优化此对,因此结果将保持全精度。

混合精度

如果 float16 转换效果不佳,您可以将大多数操作转换为 float16,但保留一些操作为 float32。auto_mixed_precision.auto_convert_mixed_precision 工具会找到一组最小的操作来跳过,同时保持一定的精度水平。您需要为模型提供一个示例输入。

由于 ONNX Runtime 的 CPU 版本不支持 float16 操作,并且该工具需要测量精度损失,因此混合精度工具必须在具有 GPU 的设备上运行

from onnxconverter_common import auto_mixed_precision
import onnx

model = onnx.load("path/to/model.onnx")
# Assuming x is the input to the model
feed_dict = {'input': x.numpy()}
model_fp16 = auto_convert_mixed_precision(model, feed_dict, rtol=0.01, atol=0.001, keep_io_types=True)
onnx.save(model_fp16, "path/to/model_fp16.onnx")

混合精度工具参数

auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol=None, keep_io_types=False)
  • model:要转换的 ONNX 模型。
  • feed_dict:用于在转换期间测量模型精度的测试数据。格式类似于 InferenceSession.run(输入名称到值的映射)
  • validate_fn:一个接受两个 numpy 数组列表(分别是 float32 模型和混合精度模型的输出)的函数,如果结果足够接近,则返回 True,否则返回 False。可以代替或除了 rtolatol 使用。
  • rtol, atol:用于验证的绝对和相对容差。有关更多信息,请参见 numpy.allclose
  • keep_io_types:模型输入/输出是否应保留为 float32。

混合精度工具通过将操作集群转换为 float16 来工作。如果集群失败,它将被分成两半,并独立尝试两个集群。当工具工作时,会打印集群大小的可视化。