Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

YoloV9-rd Ckpt to Onnx #212

Rhroan started this conversation in Show and tell
Discussion options

#!/usr/bin/env python3
"""
2025年07月09日 = Test completed
export_yolov9_full_onnx.py

Export a self-contained ONNX model (yolov9_full.onnx),
including full post-processing (decode + NMS), so that it can run without PyTorch during inference.
"""

import torch
import torch.nn as nn
from hydra import initialize, compose
from yolo.model.yolo import create_model
from yolo import create_converter, PostProcess

# --- Export settings: edit these to match the training run being exported. ---

# Number of object classes the checkpoint was trained with (1 = single-class).
# Used both as the Hydra ``dataset.class_num`` override and as
# ``create_model(..., class_num=...)``, so it must equal the training value.
NUM_CLASSES = 1
# Checkpoint file read with ``torch.load``. Its "state_dict" entry (or the
# loaded object itself when that key is absent) is loaded with strict=True,
# so the keys must match the model built from ``ModeName`` exactly.
CKPT_PATH = "v9-t-epoch=40-map-map=0.8745.ckpt"
# Config name selected via the Hydra ``model=`` override; must be the same
# model config used during training. (Name likely intended "ModelName"; kept
# as-is because other code may reference it.)
ModeName = "v9-t"
# Destination path of the exported ONNX file.
ExportOnnxFile = "yolov9t_full.onnx"

class FullYoloExport(nn.Module):
    """Wrap a YOLO network together with its decode + NMS post-processing.

    Putting both in one ``nn.Module`` lets ``torch.onnx.export`` trace the
    network *and* the post-processing into a single graph, so the exported
    ONNX file yields final detections without PyTorch at inference time.

    Args:
        model: The network built by ``create_model``; switched to eval mode.
        cfg: Composed Hydra config. Reads ``cfg.model.name``,
            ``cfg.model.anchor``, ``cfg.image_size`` and ``cfg.task.nms``.
    """

    def __init__(self, model, cfg):
        # Must be the dunder ``__init__`` / ``super().__init__()``: nn.Module
        # sets up its parameter and submodule registries there, and only
        # ``__init__`` receives the constructor arguments.
        super().__init__()
        self.model = model.eval()
        self.converter = create_converter(
            cfg.model.name, model, cfg.model.anchor, cfg.image_size, device="cpu"
        )
        self.nms_cfg = cfg.task.nms
        # Build the post-processor once rather than on every forward pass;
        # it only depends on the converter and NMS config fixed above.
        self.post_process = PostProcess(self.converter, self.nms_cfg)

    def forward(self, x):
        """Run the network and post-processing on an input batch.

        Args:
            x: Image tensor fed directly to ``self.model``; presumably
                (batch, 3, H, W) matching ``cfg.image_size`` - confirm.

        Returns:
            The sequence of detection tensors produced by the post-processor,
            concatenated along dim 0. Expected shape ``[N, 6]`` (presumably
            box coordinates, score and class - confirm against PostProcess).
        """
        preds = self.model(x)
        # The model may return a dict of heads or the main head directly.
        heads = preds["Main"] if isinstance(preds, dict) else preds
        dets = self.post_process({"Main": heads}, None)
        # NOTE(review): concatenation discards which batch image each
        # detection belongs to; the output is unambiguous only for batch 1.
        return torch.cat(dets, dim=0)

def export_full_onnx(ckpt_path=None, onnx_path=None, opset_version=16):
    """Export the checkpoint as one ONNX graph including decode + NMS.

    Composes the Hydra config for ``ModeName``, builds the model, loads the
    checkpoint weights strictly, wraps everything in ``FullYoloExport`` and
    traces it with a random dummy input of size ``cfg.image_size``.

    Args:
        ckpt_path: Checkpoint to load. ``None`` uses ``CKPT_PATH``.
        onnx_path: Output ONNX file. ``None`` uses ``ExportOnnxFile``.
        opset_version: ONNX opset passed to ``torch.onnx.export``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the
            checkpoint keys do not match the model built from the config.
    """
    # Resolve defaults at call time so later edits to the module constants
    # are still honoured.
    ckpt_path = CKPT_PATH if ckpt_path is None else ckpt_path
    onnx_path = ExportOnnxFile if onnx_path is None else onnx_path

    with initialize(config_path="../../yolo/config", version_base=None):
        cfg = compose(
            config_name="config.yaml",
            overrides=[
                "task=inference",
                f"model={ModeName}",
                f"dataset.class_num={NUM_CLASSES}",
            ],
        )
        model = create_model(cfg.model, class_num=NUM_CLASSES)

        # SECURITY: weights_only=False un-pickles arbitrary objects, which can
        # execute code. It is set explicitly because training checkpoints
        # usually carry non-tensor metadata that the weights_only=True default
        # (PyTorch >= 2.6) rejects. Only load checkpoints you trust.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Training checkpoints nest weights under "state_dict"; a bare state
        # dict is accepted as-is.
        state_dict = ckpt.get("state_dict", ckpt)
        model.load_state_dict(state_dict, strict=True)

        wrapper = FullYoloExport(model, cfg).eval()
        dummy = torch.randn(1, 3, cfg.image_size[0], cfg.image_size[1])
        torch.onnx.export(
            wrapper,
            dummy,
            onnx_path,
            opset_version=opset_version,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                # Axis keys must be integers (0), not strings ("0").
                # NOTE(review): if the post-processing loops over the batch
                # in Python, tracing with batch 1 bakes that loop into the
                # graph; verify other batch sizes with onnxruntime.
                "input": {0: "batch"},
                # The number of surviving detections is data-dependent.
                "output": {0: "num_detections"},
            },
        )
    print(f"✅ Exported {onnx_path}")

# Run the export only when executed as a script, never on import.
# (``__name__``/``"__main__"`` are the dunder names Python sets; a bare
# ``name`` is undefined and raises NameError.)
if __name__ == "__main__":
    export_full_onnx()

You must be logged in to vote

Replies: 0 comments

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
1 participant

AltStyle によって変換されたページ (->オリジナル) /