@@ -26,6 +26,7 @@ def get_args():
2626 parser .add_argument ('--use_medianflow' , action = 'store_true' )
2727 parser .add_argument ('--use_tld' , action = 'store_true' )
2828 parser .add_argument ('--use_nano' , action = 'store_true' )
29+ parser .add_argument ('--use_vit' , action = 'store_true' )
2930
3031 args = parser .parse_args ()
3132
@@ -63,6 +64,10 @@ def initialize_tracker_list(window_name, image, tracker_algorithm_list):
6364 # params.backbone = "model/nanotrackv3/nanotrack_backbone_sim.onnx"
6465 # params.neckhead = "model/nanotrackv3/nanotrack_head_sim.onnx"
6566 tracker = cv .TrackerNano_create (params )
67+ if tracker_algorithm == 'Vit' :
68+ params = cv .TrackerVit_Params ()
69+ params .net = "model/vit/object_tracking_vittrack_2023sep.onnx"
70+ tracker = cv .TrackerVit_create (params )
6671 if tracker_algorithm == 'CSRT' :
6772 tracker = cv .TrackerCSRT_create ()
6873 if tracker_algorithm == 'KCF' :
@@ -125,6 +130,7 @@ def main():
125130 use_medianflow = args .use_medianflow
126131 use_tld = args .use_tld
127132 use_nano = args .use_nano
133+ use_vit = args .use_vit
128134
129135 # 使用アルゴリズム #########################################################
130136 tracker_algorithm_list = []
@@ -148,7 +154,9 @@ def main():
148154 tracker_algorithm_list .append ('TLD' )
149155 if use_nano :
150156 tracker_algorithm_list .append ('Nano' )
151- 157+ if use_vit :
158+ tracker_algorithm_list .append ('Vit' )
159+ 152160 if len (tracker_algorithm_list ) == 0 :
153161 tracker_algorithm_list .append ('DaSiamRPN' )
154162 print (tracker_algorithm_list )
0 commit comments