1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
| package main
import (
"fmt"
"image"
"image/color"
"math"
"os"
"sort"
"gocv.io/x/gocv"
ort "github.com/yalue/onnxruntime_go"
)
// ========== Configuration ==========
// Compile-time settings shared by the inference pipeline in this file.
const (
// ONNX_LIB_PATH is the ONNX Runtime shared-library path handed to
// InitONNXRuntime by main.
ONNX_LIB_PATH = "/usr/local/lib/libonnxruntime.so.1.24.1"
// MODEL_PATH is the model file handed to NewModelSession by main.
MODEL_PATH = "yolo26n.onnx"
// INPUT_SIZE is the side length, in pixels, that PreprocessImage resizes the
// image to; PostProcess divides the original image size by it to scale boxes
// back.
INPUT_SIZE = 640
// CONF_THRESH is the minimum best-class score a box needs to be kept by
// PostProcess.
CONF_THRESH = 0.25
// IOU_THRESH is the overlap threshold PostProcess passes to NMS.
IOU_THRESH = 0.45
// NUM_CLASSES is the number of per-class score rows PostProcess scans for
// each box.
NUM_CLASSES = 80
)
// CLASS_NAMES maps a class index to its COCO label.
//
// The list must hold one entry per model class (80 for COCO); PostProcess
// indexes it with the winning class index, so a shorter list would make that
// lookup go out of range.
var CLASS_NAMES = []string{
	"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
	"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
	"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
	"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
	"kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
	"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
	"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
	"couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
	"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
	"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
	"toothbrush",
}
// Detection is one detected object as produced by PostProcess.
type Detection struct {
// X1, Y1 and X2, Y2 are two opposite corners of the box. PostProcess fills
// them as center minus / plus half the box size, after scaling to the
// original image size.
X1, Y1, X2, Y2 float32
// Confidence is the highest per-class score found for this box.
Confidence float32
// ClassID is the index of the class that produced Confidence.
ClassID int
// ClassName is the CLASS_NAMES entry for ClassID.
ClassName string
}
// main runs the detection pipeline and exits with status 1 if any step fails.
//
// The work lives in run so that every deferred cleanup executes before the
// process exits; calling os.Exit in the middle of the pipeline would skip the
// deferred calls that release the runtime, the session and the image.
func main() {
	if err := run(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}

// run loads the model, runs inference on test.jpg, writes the annotated image
// to result.jpg and prints the detections. It returns the first error it hits.
func run() error {
	// 1. Initialize ONNX Runtime.
	if err := InitONNXRuntime(ONNX_LIB_PATH); err != nil {
		return fmt.Errorf("初始化ONNX Runtime失败: %w", err)
	}
	defer ort.DestroyEnvironment()

	// 2. Load the model.
	session, err := NewModelSession(MODEL_PATH)
	if err != nil {
		return fmt.Errorf("加载模型失败: %w", err)
	}
	defer session.Close()
	fmt.Println("✅ 模型加载成功")

	// 3. Read the image. The Mat is closed even when it turns out to be empty.
	img := gocv.IMRead("test.jpg", gocv.IMReadColor)
	defer img.Close()
	if img.Empty() {
		return fmt.Errorf("无法读取图片")
	}
	originalW := float32(img.Cols())
	originalH := float32(img.Rows())

	// 4. Fill the input tensor from the image.
	PreprocessImage(img, session.Input)

	// 5. Run inference.
	if err := session.Session.Run(); err != nil {
		return fmt.Errorf("推理失败: %w", err)
	}

	// 6. Decode the output tensor into detections.
	detections := PostProcess(session.Output.GetData(), originalW, originalH)

	// 7. Draw the detections and save the image. A failed write is reported
	// after the results have been printed, so the console output is not lost.
	DrawDetections(&img, detections)
	saved := gocv.IMWrite("result.jpg", img)

	// 8. Print the results.
	fmt.Printf("\n📊 检测结果: 共发现 %d 个目标\n", len(detections))
	for i, det := range detections {
		fmt.Printf("%2d. %-15s 置信度: %.3f 位置: [%.0f, %.0f, %.0f, %.0f]\n",
			i+1, det.ClassName, det.Confidence, det.X1, det.Y1, det.X2, det.Y2)
	}

	if !saved {
		return fmt.Errorf("无法保存结果图片: result.jpg")
	}
	return nil
}
// PreprocessImage fills the input tensor from img.
//
// The image is resized to INPUT_SIZE x INPUT_SIZE with linear interpolation
// (a plain stretch, no aspect-ratio padding), converted to float32 and divided
// by 255. Each pixel is then written into the tensor data as three consecutive
// planes of INPUT_SIZE*INPUT_SIZE values, taking the pixel's components in
// reverse order (index 2, 1, 0), i.e. BGR -> RGB in planar (NCHW) layout.
func PreprocessImage(img gocv.Mat, input *ort.Tensor[float32]) {
	const planeLen = INPUT_SIZE * INPUT_SIZE

	dst := input.GetData()

	// Resize into a scratch Mat that is released when the function returns.
	scaled := gocv.NewMat()
	defer scaled.Close()
	gocv.Resize(img, &scaled, image.Pt(INPUT_SIZE, INPUT_SIZE), 0, 0, gocv.InterpolationLinear)

	// Convert to float and normalize by 255.
	scaled.ConvertTo(&scaled, gocv.MatTypeCV32F)
	scaled.DivideFloat(255.0)

	// Interleaved pixels -> three planes, with the channel order reversed.
	for row := 0; row < INPUT_SIZE; row++ {
		rowBase := row * INPUT_SIZE
		for col := 0; col < INPUT_SIZE; col++ {
			px := scaled.GetVecfAt(row, col)
			offset := rowBase + col
			dst[offset] = px[2]            // R
			dst[offset+planeLen] = px[1]   // G
			dst[offset+planeLen*2] = px[0] // B
		}
	}
}
// PostProcess decodes the raw model output into detections expressed in
// original-image pixels and filters them with NMS.
//
// output is read as a row-major [4+NUM_CLASSES, N] tensor: row 0..3 hold
// cx, cy, w, h for each of the N boxes and row 4+c holds the score of class c.
// N is derived from len(output) rather than hard-coded, so a tensor with a
// different number of boxes no longer indexes out of range. This assumes the
// tensor holds a single image.
//
// imgW and imgH are the original image size; box coordinates are scaled by
// imgW/INPUT_SIZE and imgH/INPUT_SIZE.
//
// If len(output) is not a positive multiple of 4+NUM_CLASSES the layout cannot
// be the expected one; a message is written to stderr and nil is returned.
func PostProcess(output []float32, imgW, imgH float32) []Detection {
	const attrs = 4 + NUM_CLASSES

	numBoxes := len(output) / attrs
	if numBoxes == 0 || len(output)%attrs != 0 {
		fmt.Fprintf(os.Stderr, "模型输出长度 %d 不是 %d 的整数倍,无法解析\n", len(output), attrs)
		return nil
	}

	var detections []Detection
	scaleX := imgW / INPUT_SIZE
	scaleY := imgH / INPUT_SIZE

	for i := 0; i < numBoxes; i++ {
		// Pick the class with the highest score for this box.
		maxConf := float32(0)
		classID := 0
		for c := 0; c < NUM_CLASSES; c++ {
			conf := output[numBoxes*(4+c)+i]
			if conf > maxConf {
				maxConf = conf
				classID = c
			}
		}
		if maxConf < CONF_THRESH {
			continue
		}

		// Fall back to a synthetic label when the name table is shorter than
		// the number of classes, instead of indexing out of range.
		var className string
		if classID < len(CLASS_NAMES) {
			className = CLASS_NAMES[classID]
		} else {
			className = fmt.Sprintf("class_%d", classID)
		}

		// Box is stored as center + size; convert to corners in image pixels.
		cx := output[i] * scaleX
		cy := output[numBoxes+i] * scaleY
		w := output[numBoxes*2+i] * scaleX
		h := output[numBoxes*3+i] * scaleY
		detections = append(detections, Detection{
			X1:         cx - w/2,
			Y1:         cy - h/2,
			X2:         cx + w/2,
			Y2:         cy + h/2,
			Confidence: maxConf,
			ClassID:    classID,
			ClassName:  className,
		})
	}

	// Drop overlapping boxes.
	return NMS(detections, IOU_THRESH)
}
// NMS performs greedy non-maximum suppression and returns the kept
// detections ordered by descending confidence.
//
// Boxes are visited from highest to lowest confidence; each kept box
// suppresses every later box whose IoU with it is greater than iouThresh.
// Suppression ignores ClassID, so overlapping boxes of different classes
// suppress each other as well.
//
// The input slice is not modified: sorting happens on a copy. The sort is
// stable, so detections with equal confidence keep their input order and the
// result is deterministic. An empty input is returned as is.
func NMS(detections []Detection, iouThresh float32) []Detection {
	if len(detections) == 0 {
		return detections
	}

	// Work on a copy so the caller's slice keeps its order.
	ordered := make([]Detection, len(detections))
	copy(ordered, detections)
	sort.SliceStable(ordered, func(i, j int) bool {
		return ordered[i].Confidence > ordered[j].Confidence
	})

	var keep []Detection
	suppressed := make([]bool, len(ordered))
	for i := range ordered {
		if suppressed[i] {
			continue
		}
		keep = append(keep, ordered[i])
		for j := i + 1; j < len(ordered); j++ {
			if suppressed[j] {
				continue
			}
			if CalculateIOU(&ordered[i], &ordered[j]) > iouThresh {
				suppressed[j] = true
			}
		}
	}
	return keep
}
// CalculateIOU returns the intersection-over-union of boxes a and b.
//
// The overlap rectangle is bounded by the larger of the two X1/Y1 values and
// the smaller of the two X2/Y2 values. When that rectangle has no positive
// width or height the boxes do not overlap and the result is 0.
func CalculateIOU(a, b *Detection) float32 {
	left, top := max(a.X1, b.X1), max(a.Y1, b.Y1)
	right, bottom := min(a.X2, b.X2), min(a.Y2, b.Y2)
	if right <= left || bottom <= top {
		return 0
	}

	overlap := (right - left) * (bottom - top)
	firstArea := (a.X2 - a.X1) * (a.Y2 - a.Y1)
	secondArea := (b.X2 - b.X1) * (b.Y2 - b.Y1)
	combined := firstArea + secondArea - overlap
	return overlap / combined
}
// DrawDetections draws every detection onto img: a 2-pixel box outline, a
// filled label background sitting on top of the box's Y1 edge, and the text
// "<class name> <confidence>" inside that background.
//
// The color is chosen from a fixed six-entry palette by ClassID modulo the
// palette length, so the same class always gets the same color.
func DrawDetections(img *gocv.Mat, detections []Detection) {
	palette := []color.RGBA{
		{255, 0, 0, 0}, {0, 255, 0, 0}, {0, 0, 255, 0},
		{255, 255, 0, 0}, {255, 0, 255, 0}, {0, 255, 255, 0},
	}

	for i := range detections {
		det := detections[i]
		boxColor := palette[det.ClassID%len(palette)]
		left, top := int(det.X1), int(det.Y1)

		// Box outline.
		outline := image.Rect(left, top, int(det.X2), int(det.Y2))
		gocv.Rectangle(img, outline, boxColor, 2)

		// Filled label background, sized from the rendered text plus padding.
		caption := fmt.Sprintf("%s %.2f", det.ClassName, det.Confidence)
		extent := gocv.GetTextSize(caption, gocv.FontHersheySimplex, 0.5, 1)
		background := image.Rect(left, top-extent.Y-10, left+extent.X, top)
		gocv.Rectangle(img, background, boxColor, -1)

		// Caption text.
		gocv.PutText(img, caption, image.Pt(left, top-5),
			gocv.FontHersheySimplex, 0.5, color.RGBA{255, 255, 255, 0}, 1)
	}
}
// max returns a when a is strictly greater than b, and b in every other case
// (including when the comparison is false because either value is NaN).
func max(a, b float32) float32 {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
// min returns a when a is strictly less than b, and b in every other case
// (including when the comparison is false because either value is NaN).
func min(a, b float32) float32 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
|