Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e1f9597

Browse files
committed
tracker work paralell with detector, work var.
chsnge req in ultralytics lib ( because need for tracker spec ver of opencv )
1 parent 6572e3f commit e1f9597

File tree

1 file changed

+275
-0
lines changed

1 file changed

+275
-0
lines changed

‎custom.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
import re
4+
import sys
5+
import copy
6+
import time
7+
import argparse
8+
9+
import cv2 as cv
10+
import numpy as np
11+
from ultralytics import YOLO
12+
13+
14+
def get_args():
15+
parser = argparse.ArgumentParser()
16+
17+
parser.add_argument("--device", default="sample_movie/bird.mp4")
18+
parser.add_argument("--width", help='cap width', type=int, default=960)
19+
parser.add_argument("--height", help='cap height', type=int, default=540)
20+
21+
parser.add_argument('--use_mil', action='store_true')
22+
parser.add_argument('--use_goturn', action='store_true')
23+
parser.add_argument('--use_dasiamrpn', action='store_true')
24+
parser.add_argument('--use_csrt', action='store_true')
25+
parser.add_argument('--use_kcf', action='store_true')
26+
parser.add_argument('--use_boosting', action='store_true')
27+
parser.add_argument('--use_mosse', action='store_true')
28+
parser.add_argument('--use_medianflow', action='store_true')
29+
parser.add_argument('--use_tld', action='store_true')
30+
parser.add_argument('--use_nano', action='store_true')
31+
parser.add_argument('--use_vit', action='store_true')
32+
33+
args = parser.parse_args()
34+
35+
return args
36+
37+
38+
def isint(s):
39+
p = '[-+]?\d+'
40+
return True if re.fullmatch(p, s) else False
41+
42+
43+
def detect_objects(frame, model):
44+
"""
45+
Object detection using YOLOv8.
46+
"""
47+
# Perform detection
48+
results = model(frame)
49+
50+
# Extract bounding boxes
51+
bboxes = []
52+
for result in results:
53+
for box in result.boxes:
54+
x1, y1, x2, y2 = box.xyxy[0] # Get the bounding box coordinates
55+
bboxes.append((int(x1), int(y1), int(x2 - x1), int(y2 - y1)))
56+
57+
return bboxes
58+
59+
60+
def initialize_tracker_list(window_name, image, tracker_algorithm_list, detected_bboxes):
61+
tracker_list = []
62+
63+
# Tracker list generation
64+
for tracker_algorithm in tracker_algorithm_list:
65+
for bbox in detected_bboxes:
66+
tracker = None
67+
if tracker_algorithm == 'MIL':
68+
tracker = cv.TrackerMIL_create()
69+
if tracker_algorithm == 'GOTURN':
70+
params = cv.TrackerGOTURN_Params()
71+
params.modelTxt = "model/GOTURN/goturn.prototxt"
72+
params.modelBin = "model/GOTURN/goturn.caffemodel"
73+
tracker = cv.TrackerGOTURN_create(params)
74+
if tracker_algorithm == 'DaSiamRPN':
75+
params = cv.TrackerDaSiamRPN_Params()
76+
params.model = "model/DaSiamRPN/dasiamrpn_model.onnx"
77+
params.kernel_r1 = "model/DaSiamRPN/dasiamrpn_kernel_r1.onnx"
78+
params.kernel_cls1 = "model/DaSiamRPN/dasiamrpn_kernel_cls1.onnx"
79+
tracker = cv.TrackerDaSiamRPN_create(params)
80+
if tracker_algorithm == 'Nano':
81+
params = cv.TrackerNano_Params()
82+
params.backbone = "model/nanotrackv2/nanotrack_backbone_sim.onnx"
83+
params.neckhead = "model/nanotrackv2/nanotrack_head_sim.onnx"
84+
tracker = cv.TrackerNano_create(params)
85+
if tracker_algorithm == 'Vit':
86+
params = cv.TrackerVit_Params()
87+
params.net = "model/vit/object_tracking_vittrack_2023sep.onnx"
88+
tracker = cv.TrackerVit_create(params)
89+
if tracker_algorithm == 'CSRT':
90+
tracker = cv.TrackerCSRT_create()
91+
if tracker_algorithm == 'KCF':
92+
tracker = cv.TrackerKCF_create()
93+
if tracker_algorithm == 'Boosting':
94+
tracker = cv.legacy_TrackerBoosting.create()
95+
if tracker_algorithm == 'MOSSE':
96+
tracker = cv.legacy_TrackerMOSSE.create()
97+
if tracker_algorithm == 'MedianFlow':
98+
tracker = cv.legacy_TrackerMedianFlow.create()
99+
if tracker_algorithm == 'TLD':
100+
tracker = cv.legacy_TrackerTLD.create()
101+
102+
if tracker is not None:
103+
tracker.init(image, bbox)
104+
tracker_list.append(tracker)
105+
106+
return tracker_list
107+
108+
109+
def main():
110+
color_list = [
111+
[255, 0, 0], # blue
112+
[255, 255, 0], # aqua
113+
[0, 255, 0], # lime
114+
[128, 0, 128], # purple
115+
[0, 0, 255], # red
116+
[255, 0, 255], # fuchsia
117+
[0, 128, 0], # green
118+
[128, 128, 0], # teal
119+
[0, 0, 128], # maroon
120+
[0, 128, 128], # olive
121+
[0, 255, 255], # yellow
122+
]
123+
124+
# Parse arguments ########################################################
125+
args = get_args()
126+
127+
cap_device = args.device
128+
cap_width = args.width
129+
cap_height = args.height
130+
131+
use_mil = args.use_mil
132+
use_goturn = args.use_goturn
133+
use_dasiamrpn = args.use_dasiamrpn
134+
use_csrt = args.use_csrt
135+
use_kcf = args.use_kcf
136+
use_boosting = args.use_boosting
137+
use_mosse = args.use_mosse
138+
use_medianflow = args.use_medianflow
139+
use_tld = args.use_tld
140+
use_nano = args.use_nano
141+
use_vit = args.use_vit
142+
143+
# Tracker algorithm selection ############################################
144+
tracker_algorithm_list = []
145+
if use_mil:
146+
tracker_algorithm_list.append('MIL')
147+
if use_goturn:
148+
tracker_algorithm_list.append('GOTURN')
149+
if use_dasiamrpn:
150+
tracker_algorithm_list.append('DaSiamRPN')
151+
if use_csrt:
152+
tracker_algorithm_list.append('CSRT')
153+
if use_kcf:
154+
tracker_algorithm_list.append('KCF')
155+
if use_boosting:
156+
tracker_algorithm_list.append('Boosting')
157+
if use_mosse:
158+
tracker_algorithm_list.append('MOSSE')
159+
if use_medianflow:
160+
tracker_algorithm_list.append('MedianFlow')
161+
if use_tld:
162+
tracker_algorithm_list.append('TLD')
163+
if use_nano:
164+
tracker_algorithm_list.append('Nano')
165+
if use_vit:
166+
tracker_algorithm_list.append('Vit')
167+
168+
if len(tracker_algorithm_list) == 0:
169+
tracker_algorithm_list.append('DaSiamRPN')
170+
print(tracker_algorithm_list)
171+
172+
# Camera setup ###########################################################
173+
if isint(cap_device):
174+
cap_device = int(cap_device)
175+
cap = cv.VideoCapture(cap_device)
176+
cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width)
177+
cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height)
178+
179+
# Load YOLOv8 model ######################################################
180+
model = YOLO(r"D:\pycharm_projects\yolov8\runs\detect\drone_v9_300ep_32bath\weights\best.pt", task='detect') # Ensure you have the correct path to your YOLOv8 model
181+
182+
# Tracker initialization #################################################
183+
window_name = 'Tracker Demo'
184+
cv.namedWindow(window_name)
185+
186+
tracker_list = []
187+
detected_bboxes = []
188+
189+
while cap.isOpened():
190+
ret, image = cap.read()
191+
if not ret:
192+
break
193+
debug_image = copy.deepcopy(image)
194+
195+
# If no tracker is initialized, run detection until an object is found
196+
if not tracker_list:
197+
detected_bboxes = detect_objects(image, model)
198+
if detected_bboxes:
199+
tracker_list = initialize_tracker_list(window_name, image, tracker_algorithm_list, detected_bboxes)
200+
201+
elapsed_time_list = []
202+
tracker_scores = [] # Initialize a list to store tracker scores
203+
204+
for index, tracker in enumerate(tracker_list):
205+
# Update tracking
206+
start_time = time.time()
207+
ok, bbox = tracker.update(image)
208+
try:
209+
tracker_score = tracker.getTrackingScore()
210+
except:
211+
tracker_score = '-'
212+
213+
elapsed_time_list.append(time.time() - start_time)
214+
tracker_scores.append(tracker_score) # Append the score to the list
215+
216+
if ok:
217+
# Draw bounding box after tracking
218+
new_bbox = [
219+
int(bbox[0]),
220+
int(bbox[1]),
221+
int(bbox[2]),
222+
int(bbox[3])
223+
]
224+
cv.rectangle(debug_image,
225+
(new_bbox[0], new_bbox[1]),
226+
(new_bbox[0] + new_bbox[2], new_bbox[1] + new_bbox[3]),
227+
color_list[index % len(color_list)],
228+
thickness=2)
229+
else:
230+
# If tracking fails, reset trackers
231+
tracker_list = []
232+
break
233+
234+
# Display processing time and tracker scores for each algorithm
235+
for index, tracker_algorithm in enumerate(tracker_algorithm_list):
236+
if index < len(elapsed_time_list):
237+
elapsed_time_ms = elapsed_time_list[index] * 1000
238+
if index < len(tracker_scores):
239+
score = tracker_scores[index]
240+
if score != '-':
241+
text = f"{tracker_algorithm} : {elapsed_time_ms:.1f}ms Score:{score:.2f}"
242+
else:
243+
text = f"{tracker_algorithm} : {elapsed_time_ms:.1f}ms"
244+
else:
245+
text = f"{tracker_algorithm} : {elapsed_time_ms:.1f}ms"
246+
else:
247+
text = f"{tracker_algorithm} : N/A"
248+
249+
cv.putText(
250+
debug_image,
251+
text,
252+
(10, int(25 * (index + 1))),
253+
cv.FONT_HERSHEY_SIMPLEX,
254+
0.7,
255+
color_list[index % len(color_list)],
256+
2,
257+
cv.LINE_AA
258+
)
259+
260+
cv.imshow(window_name, debug_image)
261+
262+
k = cv.waitKey(1)
263+
if k == 32: # SPACE
264+
# Reinitialize trackers based on new selection
265+
detected_bboxes = detect_objects(image, model)
266+
tracker_list = initialize_tracker_list(window_name, image, tracker_algorithm_list, detected_bboxes)
267+
if k == 27: # ESC
268+
break
269+
270+
cap.release()
271+
cv.destroyAllWindows()
272+
273+
274+
if __name__ == '__main__':
275+
main()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /