import os
import sys
from multiprocessing import Value

import cv2
import numpy as np
import pyautogui
import tensorflow as tf

cap = cv2.VideoCapture(0)

sys.path.append("..")

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# # Model preparation

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = 'snake/frozen_inference_graph.pb'

# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = os.path.join('images/data', 'object-detection.pbtxt')

NUM_CLASSES = 4

# ## Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
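
# Note: GraphDef, gfile and Session are TensorFlow 1.x APIs. Under TensorFlow
# 2.x the same script can run via the compatibility layer, e.g.:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()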
36
+
37
+ # ## Loading label map
38
+ label_map = label_map_util .load_labelmap (PATH_TO_LABELS )
39
+ categories = label_map_util .convert_label_map_to_categories (label_map , max_num_classes = NUM_CLASSES ,
40
+ use_display_name = True )
41
+ category_index = label_map_util .create_category_index (categories )
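# category_index maps class ids to entries like {1: {'id': 1, 'name': ...}};
# the actual names depend on the labels defined in object-detection.pbtxt.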

with detection_graph.as_default():
    # from directkeys import PressKey, ReleaseKey, W

    # enter your monitor's resolution or use a library to fetch this - I had to
    # hard-code it due to issues with a dual monitor setup
    x, y = 288, 512

    # init process-safe variables for workers
    objectX, objectY = Value('d', 0.0), Value('d', 0.0)
    objectX_previous = None
    objectY_previous = None
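
    # A sketch of fetching the resolution automatically instead of hard-coding
    # it (assumes a single-monitor setup; pyautogui.size() returns the primary
    # screen's (width, height)):
    # x, y = pyautogui.size()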

    with tf.Session(graph=detection_graph) as sess:
        # Define input and output Tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for each of the objects.
        # The score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        while True:
            ret, image_np = cap.read()
            if not ret:
                # stop if the camera frame could not be read
                break
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)
            # Actual detection.
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
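            # Output shapes: boxes is (1, N, 4) with normalized
            # [ymin, xmin, ymax, xmax] coordinates; scores and classes are
            # (1, N), sorted by descending score.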
            # Visualization of the results of a detection.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            cv2.imshow('controls detection', image_np)
            if cv2.waitKey(50) & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                break
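            # (cv2.waitKey returns an int; masking with 0xFF keeps only the low
            # byte so the comparison with ord('q') behaves consistently across
            # platforms.)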

            '''MOVE'''
            # map each detected gesture class to an arrow-key press

            # press 'up' if a class-1 box is detected above the score threshold
            objects = np.where(classes[0] == 1)[0]
            if len(objects) > 0 and scores[0][objects][0] > 0.15:
                pyautogui.press('up')

            # press 'down' if a class-2 box is detected above the score threshold
            objects = np.where(classes[0] == 2)[0]
            if len(objects) > 0 and scores[0][objects][0] > 0.15:
                pyautogui.press('down')

            # press 'left' if a class-3 box is detected above the score threshold
            objects = np.where(classes[0] == 3)[0]
            if len(objects) > 0 and scores[0][objects][0] > 0.15:
                pyautogui.press('left')

            # press 'right' if a class-4 box is detected above the score threshold
            objects = np.where(classes[0] == 4)[0]
            if len(objects) > 0 and scores[0][objects][0] > 0.15:
                pyautogui.press('right')
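
            # A sketch of computing a box center, which could feed the
            # (currently unused) objectX/objectY values, assuming `objects`
            # holds the detections of interest:
            # if len(objects) > 0:
            #     ymin, xmin, ymax, xmax = boxes[0][objects[0]]
            #     objectX.value = (xmin + xmax) / 2 * x   # normalized -> pixels
            #     objectY.value = (ymin + ymax) / 2 * y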

        cap.release()