导出带有特征图 shape 的 ONNX 模型

less than 1 minute read

Published:

import onnx
import onnxsim
import torch
from onnx import shape_inference

# Export the PyTorch model to an ONNX file.
# NOTE(review): `model_restoration` and `input_` are not defined in this
# snippet; they are expected to be created before this point.
onnx_path = "restormer.onnx"

export_options = dict(
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    # Axes 2 and 3 of both tensors stay symbolic ("height" / "width") so the
    # exported graph is not tied to the size of the example input.
    dynamic_axes={
        tensor_name: {2: "height", 3: "width"}
        for tensor_name in ("input", "output")
    },
)
torch.onnx.export(model_restoration, input_, onnx_path, **export_options)
print(f"Model exported to {onnx_path}")

# 3. Simplify the ONNX model.
simplified_onnx_path = "restormer_simplified.onnx"
model = onnx.load(onnx_path)
model_simplified, check = onnxsim.simplify(model)
if check:
    onnx.save(model_simplified, simplified_onnx_path)
    print(f"Simplified ONNX model saved to {simplified_onnx_path}")
    model_for_inference = model_simplified
else:
    # Nothing was written to simplified_onnx_path in this case, so reading it
    # back would either fail or silently pick up a stale file from an earlier
    # run. Continue with the un-simplified export instead.
    print("Simplification failed!")
    print(f"Falling back to {onnx_path} for shape inference")
    model_for_inference = model

# 4. Run ONNX shape inference so that intermediate feature maps get a shape
#    recorded in graph.value_info. The in-memory model is used directly; there
#    is no need to reload it from disk.
inferred_model = shape_inference.infer_shapes(model_for_inference)
onnx.save(inferred_model, "restormer_inferred.onnx")
print("Shape inference completed and model saved as 'restormer_inferred.onnx'")

# 5. Print the inferred shape of every intermediate feature map.
def _describe_dim(dim):
    """Return a printable size for one ONNX shape dimension.

    A dimension holds either a concrete integer (``dim_value``), a symbolic
    name (``dim_param``, e.g. "height"), or nothing at all. Checking which
    field is actually set avoids treating a real size of 0 as unknown and
    keeps the symbolic name instead of discarding it.
    """
    which = dim.WhichOneof("value")
    if which == "dim_value":
        return dim.dim_value
    if which == "dim_param":
        return dim.dim_param
    return "dynamic"


# graph.value_info holds one entry per intermediate tensor (not per node).
for value_info in inferred_model.graph.value_info:
    name = value_info.name
    shape = [_describe_dim(dim) for dim in value_info.type.tensor_type.shape.dim]
    print(f"Feature Map Name: {name}, Shape: {shape}")