描述
PaddleX 当前为绕开 ROCm/HIP 上 BF16 算子缺失,在 PaddleOCR-VL 上保留两类 workaround,且这两类 workaround 来自不同的根因,需要 Paddle 框架修复后整体移除:
Workaround 1:_keep_in_fp32_modules = ["visual", "mlp_AR"]
paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_paddleocr_vl.py:70,把 SigLIP 视觉编码器与多模态 projector 强制保持 FP32。注释:「Keep visual encoder in fp32 for ROCm stability (MIOpen bf16 conv has bugs)」。
与已有 issue / PR 的关系:
实际跑 PaddleOCR-VL-1.5 端到端 BF16 推理发现:单靠 #78587 注册 BF16 conv 内核还不够——SigLIP 视觉编码器的 BF16 layer_norm 与 BF16 softmax 也会崩。也就是说在没有补齐 layer_norm / softmax 之前,无法把 _keep_in_fp32_modules 安全地改成 None。
Workaround 2:runner.py 中 4 处 paddle.is_compiled_with_rocm() 的 delete_pass
paddlex/inference/models/runners/paddle_static/runner.py 行 406-408、462-464、496-498、505-507,每处都是:
if paddle.is_compiled_with_rocm():
config.delete_pass("conv2d_add_act_fuse_pass")
config.delete_pass("conv2d_add_fuse_pass")
与现有 issue / PR 的关系:#5076 / #5077 / #5081 都没有覆盖 runner.py 这部分。两个 PIR pass 把 conv2d + add[+ act] 改写成 fused_conv2d_add_act —— 这是与 conv2d / conv3d 不同的算子,#78587 注册的 BF16 conv 内核不影响 fused_conv2d_add_act,所以这 4 处 delete_pass 在 #78587 合入后依然必要。
影响
- 视觉塔 FP32 → 显存近乎翻倍,限制可部署模型规模与批大小。
- 静态图 4 处
delete_pass → CUDA 与 ROCm 推理路径需要分叉维护。
期望
Paddle 框架补齐剩余的 HIP BF16 缺口(PaddlePaddle/Paddle#78710 / PR PaddlePaddle/Paddle#78711:BF16 layer_norm 注册、BF16 softmax 走矩阵 kernel、conv2d_add[_act]_fuse_pass 在 HIP 上不再注册)后,PaddleX 应:
验证
MI300X (gfx942) / ROCm 7.2 / PaddleOCR-VL-1.5:移除上述 workaround 后端到端 BF16 推理输出与 FP32-fallback 路径语义一致;GPU kernel 总耗时由 4 415.7 ms 降到 3 915.5 ms(1.13×),FP32 GEMM 调用由 18 756 → 1 316。完整 rocprofv3 数据见 PaddlePaddle/Paddle#78711 PR 描述里附的 BF16 benchmark。
修复 PR
#5096(依赖 PaddlePaddle/Paddle#78711 与 PaddlePaddle/Paddle#78587 都合入并发版后才能合)。
描述
PaddleX 当前为绕开 ROCm/HIP 上 BF16 算子缺失,在 PaddleOCR-VL 上保留两类 workaround,且这两类 workaround 来自不同的根因,需要 Paddle 框架修复后整体移除:
Workaround 1:
_keep_in_fp32_modules = ["visual", "mlp_AR"]paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_paddleocr_vl.py:70,把 SigLIP 视觉编码器与多模态 projector 强制保持 FP32。注释:「Keep visual encoder in fp32 for ROCm stability (MIOpen bf16 conv has bugs)」。Workaround 2:
runner.py中 4 处paddle.is_compiled_with_rocm()的delete_passpaddlex/inference/models/runners/paddle_static/runner.py行 406-408、462-464、496-498、505-507,每处都是:影响
delete_pass→ CUDA 与 ROCm 推理路径需要分叉维护。期望
Paddle 框架补齐剩余的 HIP BF16 缺口(PaddlePaddle/Paddle#78710 / PR PaddlePaddle/Paddle#78711:BF16
layer_norm注册、BF16 softmax 走矩阵 kernel、conv2d_add[_act]_fuse_pass在 HIP 上不再注册)后,PaddleX 应:_keep_in_fp32_modules(与 fix(doc_vlm): remove ROCm BF16 _keep_in_fp32_modules workaround in PaddleOCR-VL #5077 重叠,若 fix(doc_vlm): remove ROCm BF16 _keep_in_fp32_modules workaround in PaddleOCR-VL #5077 先合则只剩本 issue 的 Workaround 2);delete_pass调用(PaddleOCR-VL: Remove ROCm BF16 _keep_in_fp32_modules workaround #5076 / fix(doc_vlm): remove ROCm BF16 _keep_in_fp32_modules workaround in PaddleOCR-VL #5077 / [ROCm] Remove BF16 workaround now that Paddle framework supports HIP BF16 #5081 均未覆盖)。验证
MI300X (gfx942) / ROCm 7.2 / PaddleOCR-VL-1.5:移除上述 workaround 后端到端 BF16 推理输出与 FP32-fallback 路径语义一致;GPU kernel 总耗时由 4 415.7 ms 降到 3 915.5 ms(1.13×),FP32 GEMM 调用由 18 756 → 1 316。完整 rocprofv3 数据见 PaddlePaddle/Paddle#78711 PR 描述里附的 BF16 benchmark。
修复 PR
#5096(依赖 PaddlePaddle/Paddle#78711 与 PaddlePaddle/Paddle#78587 都合入并发版后才能合)。