Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1a1e624

Browse files
committed
new feature
1 parent 6853cd3 commit 1a1e624

File tree

3 files changed

+237
-0
lines changed

3 files changed

+237
-0
lines changed

‎custom_algo.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
import re
4+
import sys
5+
import copy
6+
import time
7+
import argparse
8+
9+
import cv2 as cv
10+
from ultralytics import YOLO # YOLO import
11+
# print(cv.getBuildInformation())
12+
13+
14+
def get_args():
15+
parser = argparse.ArgumentParser()
16+
17+
parser.add_argument("--device", default="sample_movie/bird.mp4")
18+
parser.add_argument("--width", help='cap width', type=int, default=640)
19+
parser.add_argument("--height", help='cap height', type=int, default=360)
20+
21+
# Existing tracker options
22+
parser.add_argument('--use_mil', action='store_true')
23+
parser.add_argument('--use_goturn', action='store_true')
24+
parser.add_argument('--use_dasiamrpn', action='store_true')
25+
parser.add_argument('--use_csrt', action='store_true')
26+
parser.add_argument('--use_kcf', action='store_true')
27+
parser.add_argument('--use_boosting', action='store_true')
28+
parser.add_argument('--use_mosse', action='store_true')
29+
parser.add_argument('--use_medianflow', action='store_true')
30+
parser.add_argument('--use_tld', action='store_true')
31+
parser.add_argument('--use_nano', action='store_true')
32+
parser.add_argument('--use_vit', action='store_true')
33+
34+
# Add argument to enable YOLO detection
35+
parser.add_argument('--use_yolo', action='store_true', help='Use YOLO for object detection')
36+
37+
args = parser.parse_args()
38+
39+
return args
40+
41+
def isint(s):
42+
p = '[-+]?\d+'
43+
return True if re.fullmatch(p, s) else False
44+
45+
def create_tracker_by_name(tracker_algorithm):
46+
tracker = None
47+
if tracker_algorithm == 'MIL':
48+
tracker = cv.TrackerMIL_create()
49+
elif tracker_algorithm == 'GOTURN':
50+
params = cv.TrackerGOTURN_Params()
51+
params.modelTxt = "model/GOTURN/goturn.prototxt"
52+
params.modelBin = "model/GOTURN/goturn.caffemodel"
53+
tracker = cv.TrackerGOTURN_create(params)
54+
elif tracker_algorithm == 'DaSiamRPN':
55+
params = cv.TrackerDaSiamRPN_Params()
56+
params.model = "model/DaSiamRPN/dasiamrpn_model.onnx"
57+
params.kernel_r1 = "model/DaSiamRPN/dasiamrpn_kernel_r1.onnx"
58+
params.kernel_cls1 = "model/DaSiamRPN/dasiamrpn_kernel_cls1.onnx"
59+
tracker = cv.TrackerDaSiamRPN_create(params)
60+
elif tracker_algorithm == 'Nano':
61+
params = cv.TrackerNano_Params()
62+
params.backbone = "model/nanotrackv2/nanotrack_backbone_sim.onnx"
63+
params.neckhead = "model/nanotrackv2/nanotrack_head_sim.onnx"
64+
tracker = cv.TrackerNano_create(params)
65+
elif tracker_algorithm == 'Vit':
66+
params = cv.TrackerVit_Params()
67+
params.net = "model/vit/object_tracking_vittrack_2023sep.onnx"
68+
tracker = cv.TrackerVit_create(params)
69+
elif tracker_algorithm == 'CSRT':
70+
tracker = cv.TrackerCSRT_create()
71+
elif tracker_algorithm == 'KCF':
72+
tracker = cv.TrackerKCF_create()
73+
elif tracker_algorithm == 'Boosting':
74+
tracker = cv.legacy.TrackerBoosting_create()
75+
elif tracker_algorithm == 'MOSSE':
76+
tracker = cv.legacy.TrackerMOSSE_create()
77+
elif tracker_algorithm == 'MedianFlow':
78+
tracker = cv.legacy.TrackerMedianFlow_create()
79+
elif tracker_algorithm == 'TLD':
80+
tracker = cv.legacy.TrackerTLD_create()
81+
return tracker
82+
83+
def initialize_tracker_list(image, tracker_algorithm, bboxes):
84+
tracker_list = []
85+
for i, bbox in enumerate(bboxes):
86+
tracker = create_tracker_by_name(tracker_algorithm)
87+
if tracker is not None:
88+
tracker.init(image, bbox)
89+
tracker_list.append((tracker, tracker_algorithm))
90+
return tracker_list
91+
92+
def main():
93+
color_list = [
94+
[255, 0, 0], # Blue
95+
[0, 255, 0], # Green
96+
[0, 0, 255], # Red
97+
[255, 255, 0], # Cyan
98+
[255, 0, 255], # Magenta
99+
[0, 255, 255], # Yellow
100+
[128, 0, 128], # Purple
101+
[128, 128, 0], # Olive
102+
[0, 128, 128], # Teal
103+
[128, 0, 0], # Maroon
104+
]
105+
106+
# Parse arguments
107+
args = get_args()
108+
109+
cap_device = args.device
110+
cap_width = args.width
111+
cap_height = args.height
112+
113+
# Prepare tracker algorithm
114+
tracker_algorithm = None
115+
if args.use_mil:
116+
tracker_algorithm = 'MIL'
117+
elif args.use_goturn:
118+
tracker_algorithm = 'GOTURN'
119+
elif args.use_dasiamrpn:
120+
tracker_algorithm = 'DaSiamRPN'
121+
elif args.use_csrt:
122+
tracker_algorithm = 'CSRT'
123+
elif args.use_kcf:
124+
tracker_algorithm = 'KCF'
125+
elif args.use_boosting:
126+
tracker_algorithm = 'Boosting'
127+
elif args.use_mosse:
128+
tracker_algorithm = 'MOSSE'
129+
elif args.use_medianflow:
130+
tracker_algorithm = 'MedianFlow'
131+
elif args.use_tld:
132+
tracker_algorithm = 'TLD'
133+
elif args.use_nano:
134+
tracker_algorithm = 'Nano'
135+
elif args.use_vit:
136+
tracker_algorithm = 'Vit'
137+
138+
# If no tracker is specified, default to CSRT
139+
if tracker_algorithm is None:
140+
tracker_algorithm = 'CSRT'
141+
142+
use_yolo = args.use_yolo # New argument for YOLO
143+
144+
print("Tracker:", tracker_algorithm)
145+
print("Use YOLO:", use_yolo)
146+
147+
# Open video capture
148+
if isint(cap_device):
149+
cap_device = int(cap_device)
150+
cap = cv.VideoCapture(cap_device)
151+
cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width)
152+
cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height)
153+
154+
# Initialize YOLO model if use_yolo is True
155+
if use_yolo:
156+
# yolo_model = YOLO('yolov8n.pt')
157+
yolo_model = YOLO("/home/artem-n/PycharmProjects/model/fly_last.pt") # You can choose a different model size
158+
159+
# Initialize trackers
160+
window_name = 'Object Detection and Tracking'
161+
cv.namedWindow(window_name)
162+
163+
ret, image = cap.read()
164+
if not ret:
165+
sys.exit("Can't read first frame")
166+
167+
bboxes = []
168+
if use_yolo:
169+
# Use YOLO to detect objects and initialize trackers
170+
results = yolo_model.predict(image, stream_buffer=False)
171+
for result in results:
172+
boxes = result.boxes
173+
for box in boxes:
174+
# Extract bounding box coordinates
175+
x1, y1, x2, y2 = box.xyxy[0]
176+
bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1))
177+
bboxes.append(bbox)
178+
else:
179+
# Manually select ROI
180+
bbox = cv.selectROI(window_name, image)
181+
bboxes.append(bbox)
182+
183+
# Initialize tracker list
184+
tracker_list = initialize_tracker_list(image, tracker_algorithm, bboxes)
185+
186+
while cap.isOpened():
187+
ret, image = cap.read()
188+
if not ret:
189+
break
190+
debug_image = image.copy()
191+
192+
# Update trackers
193+
for index, (tracker, tracker_algorithm) in enumerate(tracker_list):
194+
ok, bbox = tracker.update(image)
195+
if ok:
196+
# Draw bounding box
197+
p1 = (int(bbox[0]), int(bbox[1]))
198+
p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
199+
color = color_list[index % len(color_list)]
200+
cv.rectangle(debug_image, p1, p2, color, 2, 1)
201+
# Display tracker type on bounding box
202+
cv.putText(debug_image, f"{tracker_algorithm}", p1, cv.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
203+
else:
204+
# Tracking failure
205+
cv.putText(debug_image, "Tracking failure detected", (10, 80), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
206+
207+
cv.imshow(window_name, debug_image)
208+
209+
k = cv.waitKey(1)
210+
if k == 32: # SPACE
211+
# Re-initialize trackers
212+
ret, image = cap.read()
213+
if not ret:
214+
break
215+
bboxes = []
216+
if use_yolo:
217+
# Re-detect objects using YOLO
218+
results = yolo_model(image)
219+
for result in results:
220+
boxes = result.boxes
221+
for box in boxes:
222+
x1, y1, x2, y2 = box.xyxy[0]
223+
bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1))
224+
bboxes.append(bbox)
225+
else:
226+
# Manually select ROI
227+
bbox = cv.selectROI(window_name, image)
228+
bboxes.append(bbox)
229+
tracker_list = initialize_tracker_list(image, tracker_algorithm, bboxes)
230+
elif k == 27: # ESC
231+
break
232+
233+
cap.release()
234+
cv.destroyAllWindows()
235+
236+
if __name__ == '__main__':
237+
main()

‎yolo11m.pt

38.8 MB
Binary file not shown.

‎yolov8n.pt

6.25 MB
Binary file not shown.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /