导出带有特征图 shape 的 ONNX 模型
Published:
import onnx
import onnxsim
from onnx import shape_inference
# Steps 1-2: export the PyTorch model to ONNX.
# `model_restoration` and `input_` are defined outside this snippet.
onnx_path = "restormer.onnx"  # destination file for the exported graph
torch.onnx.export(
    model_restoration,
    input_,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    # Axes 2 and 3 are labelled "height" and "width" on both tensors so the
    # exported graph allows dynamic input sizes.
    dynamic_axes={
        tensor_name: {2: "height", 3: "width"}
        for tensor_name in ("input", "output")
    },
    opset_version=11,  # ONNX opset version
)
print(f"Model exported to {onnx_path}")
# Step 3: simplify the exported ONNX model with onnxsim.
model = onnx.load(onnx_path)
simplified_onnx_path = "restormer_simplified.onnx"
# simplify() yields the rewritten graph plus a flag; the simplified file is
# written only when that flag is truthy (presumably onnxsim's own validation
# result - confirm against the installed onnxsim version).
model_simplified, check = onnxsim.simplify(model)
if not check:
    print("Simplification failed!")
else:
    onnx.save(model_simplified, simplified_onnx_path)
    print(f"Simplified ONNX model saved to {simplified_onnx_path}")
# Step 4: run ONNX shape inference so every intermediate feature map gets a
# shape annotation in graph.value_info.
#
# The graph is taken from memory rather than re-read from
# "restormer_simplified.onnx": that file is written only when simplification
# succeeded, so loading it unconditionally would fail on a missing file (or
# silently pick up a stale file from an earlier run) after a failed
# simplification. In that case shape inference falls back to the original
# exported model.
if check:
    model_for_inference = model_simplified
else:
    print("Falling back to the un-simplified model for shape inference.")
    model_for_inference = model
inferred_model = shape_inference.infer_shapes(model_for_inference)
onnx.save(inferred_model, "restormer_inferred.onnx")
print("Shape inference completed and model saved as 'restormer_inferred.onnx'")
# Step 5: print the inferred shape of every intermediate feature map.
def _describe_dim(dim):
    """Return a printable size for a single ONNX tensor dimension.

    A dimension holds either a concrete ``dim_value``, a symbolic
    ``dim_param`` name, or neither. Field presence is checked explicitly so
    that a real zero-length dimension is not mistaken for an unset one.

    Args:
        dim: One entry of ``tensor_type.shape.dim``.

    Returns:
        The integer size when it is statically known, the symbolic name
        (e.g. ``"height"``) when one is recorded, otherwise ``"dynamic"``.
    """
    if dim.HasField("dim_value"):
        return dim.dim_value
    if dim.HasField("dim_param") and dim.dim_param:
        return dim.dim_param
    return "dynamic"


for node in inferred_model.graph.value_info:
    name = node.name
    shape = [_describe_dim(dim) for dim in node.type.tensor_type.shape.dim]
    print(f"Feature Map Name: {name}, Shape: {shape}")