-
Notifications
You must be signed in to change notification settings - Fork 264
YoloV9-rd Ckpt to Onnx #212
-
#!/usr/bin/env python3
"""
export_yolov9_full_onnx.py (test completed on 2025-07-09)

Export a self-contained ONNX model (yolov9t_full.onnx by default) that includes
the full post-processing (decode + NMS), so that inference can run without PyTorch.
"""
import torch
import torch.nn as nn
from hydra import initialize, compose
from yolo.model.yolo import create_model
from yolo import create_converter, PostProcess
# Number of object classes used in training.
# Set to 1 for single-class; update accordingly for multi-class.
NUM_CLASSES = 1
# Path to the PyTorch checkpoint file that is loaded for export.
CKPT_PATH = "v9-t-epoch=40-map-map=0.8745.ckpt"
# Model name as defined in the Hydra config (passed as the `model=` override).
# Must match the one used during training.
ModeName = "v9-t"
# Output ONNX file name.
ExportOnnxFile = "yolov9t_full.onnx"
class FullYoloExport(nn.Module):
    """Wrap a YOLO model together with its post-processing for export.

    The forward pass runs the wrapped model, feeds its "Main" head outputs
    through ``PostProcess`` (built from the converter and the NMS config),
    and concatenates the resulting detections into a single tensor.
    """

    def __init__(self, model, cfg):
        """Store the model (in eval mode) and build the prediction converter.

        Args:
            model: The network to wrap; it is switched to eval mode.
            cfg: Composed config. Reads ``cfg.model.name``,
                ``cfg.model.anchor``, ``cfg.image_size`` and ``cfg.task.nms``.
        """
        # Must be the real dunder constructor: nn.Module has to be initialised
        # before sub-modules can be assigned as attributes.
        super().__init__()
        self.model = model.eval()
        self.converter = create_converter(
            cfg.model.name, model, cfg.model.anchor, cfg.image_size, device="cpu"
        )
        self.nms_cfg = cfg.task.nms

    def forward(self, x):
        """Run the model and post-processing on ``x``.

        Args:
            x: Input batch passed straight to the wrapped model.

        Returns:
            The per-image detection tensors concatenated along dim 0.
        """
        preds = self.model(x)
        # The model may return a dict keyed by head name or the heads directly.
        heads = preds["Main"] if isinstance(preds, dict) else preds
        dets = PostProcess(self.converter, self.nms_cfg)({"Main": heads}, None)
        return torch.cat(dets, dim=0)  # Output shape: [N, 6]
def export_full_onnx():
    """Export the checkpoint at ``CKPT_PATH`` to the ONNX file ``ExportOnnxFile``.

    Steps:
        1. Compose the Hydra config for the inference task, the model named
           by ``ModeName`` and ``NUM_CLASSES`` classes.
        2. Build the model and load the checkpoint weights on CPU
           (strict key matching).
        3. Wrap the model in ``FullYoloExport`` and export it with
           ``torch.onnx.export`` using a random dummy input sized from
           ``cfg.image_size``.
    """
    # NOTE(review): config_path is a relative path; presumably it is resolved
    # relative to this script's location — confirm against the Hydra docs.
    with initialize(config_path="../../yolo/config", version_base=None):
        cfg = compose(
            config_name="config.yaml",
            overrides=[
                "task=inference",
                f"model={ModeName}",
                f"dataset.class_num={NUM_CLASSES}",
            ],
        )

    model = create_model(cfg.model, class_num=NUM_CLASSES)
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    # Use the nested "state_dict" entry when present; otherwise treat the
    # loaded object itself as the state dict.
    sd = ckpt.get("state_dict", ckpt)
    model.load_state_dict(sd, strict=True)

    wrapper = FullYoloExport(model, cfg).eval()
    dummy = torch.randn(1, 3, cfg.image_size[0], cfg.image_size[1])
    torch.onnx.export(
        wrapper,
        dummy,
        ExportOnnxFile,
        opset_version=16,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch"},  # Use integer 0, not string "0"
            # The number of detections varies per input, so the first output
            # axis must not be fixed to what the dummy input produced.
            "output": {0: "num_detections"},
        },
    )
    # Report the file that was actually written rather than a fixed name.
    print(f"✅ Exported {ExportOnnxFile}")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    export_full_onnx()
Beta Was this translation helpful? Give feedback.