# based on https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/detect_picamera.py
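#
# Example invocation (detect.py is a hypothetical file name; the flags are
# defined in main() below):
#   python3 detect.py --model model.tflite --labels labels.txt --use_edgetpu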
from imutils.video import VideoStream, FPS
from tflite_runtime.interpreter import Interpreter, load_delegate
import argparse
import time
import cv2
import re
from PIL import Image, ImageDraw, ImageFont
import numpy as np


def draw_image(image, results, labels, size):
    """Draws the bounding box and label for each detection and shows the frame."""
    # Prepare image for drawing
    draw = ImageDraw.Draw(image)
    for obj in results:
        print(obj)

        # Prepare boundary box, scaling normalized coordinates to the image size
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * size[0])
        xmax = int(xmax * size[0])
        ymin = int(ymin * size[1])
        ymax = int(ymax * size[1])

        # Draw rectangle to desired thickness, offsetting one pixel per pass
        for offset in range(0, 4):
            draw.rectangle((xmin - offset, ymin - offset, xmax + offset, ymax + offset),
                           outline=(255, 255, 0))

        # Annotate image with label and confidence score
        display_str = labels[int(obj['class_id'])] + ": " + str(round(obj['score'] * 100, 2)) + "%"
        draw.text((xmin, ymin), display_str,
                  font=ImageFont.truetype("/usr/share/fonts/truetype/piboto/Piboto-Regular.ttf", 20))

    # PIL images are RGB; convert back to BGR so OpenCV displays colors correctly
    displayImage = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.imshow('Coral Live Object Detection', displayImage)


def load_labels(path):
    """Loads the labels file. Supports files with or without index numbers."""
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        labels = {}
        for row_number, content in enumerate(lines):
            pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
            if len(pair) == 2 and pair[0].strip().isdigit():
                labels[int(pair[0])] = pair[1].strip()
            else:
                labels[row_number] = pair[0].strip()
    return labels


def set_input_tensor(interpreter, image):
    """Sets the input tensor."""
    tensor_index = interpreter.get_input_details()[0]['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image


def get_output_tensor(interpreter, index):
    """Returns the output tensor at the given index."""
    output_details = interpreter.get_output_details()[index]
    tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
    return tensor


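# The four output tensors read below follow the ordering produced by the
# TFLite_Detection_PostProcess op used by SSD detection models
# (boxes, class indices, scores, detection count), which this script assumes.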
def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # Get all output details
    boxes = get_output_tensor(interpreter, 0)
    classes = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': classes[i],
                'score': scores[i]
            }
            results.append(result)
    return results


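# An optional '@<device>' suffix on the model path (e.g. 'model.tflite@usb:0')
# pins the Edge TPU delegate to a specific device; the device strings follow
# libedgetpu's conventions.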
def make_interpreter(model_file, use_edgetpu):
    model_file, *device = model_file.split('@')
    if use_edgetpu:
        return Interpreter(
            model_path=model_file,
            experimental_delegates=[
                load_delegate('libedgetpu.so.1',
                              {'device': device[0]} if device else {})
            ]
        )
    else:
        return Interpreter(model_path=model_file)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', type=str, required=True, help='File path of .tflite file.')
    parser.add_argument('-l', '--labels', type=str, required=True, help='File path of labels file.')
    parser.add_argument('-t', '--threshold', type=float, default=0.4, required=False, help='Score threshold for detected objects.')
    parser.add_argument('-p', '--picamera', action='store_true', default=False, help='Use PiCamera for image capture')
    parser.add_argument('-e', '--use_edgetpu', action='store_true', default=False, help='Use EdgeTPU')
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = make_interpreter(args.model, args.use_edgetpu)
    interpreter.allocate_tensors()
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

    # Initialize video stream
    vs = VideoStream(usePiCamera=args.picamera, resolution=(640, 480)).start()
    time.sleep(1)

    fps = FPS().start()

    while True:
        try:
            # Read frame from video; OpenCV frames are BGR, so convert to RGB
            # for PIL and the model
            screenshot = vs.read()
            image = Image.fromarray(cv2.cvtColor(screenshot, cv2.COLOR_BGR2RGB))
            image_pred = image.resize((input_width, input_height), Image.ANTIALIAS)

            # Perform inference
            results = detect_objects(interpreter, image_pred, args.threshold)

            draw_image(image, results, labels, image.size)

            if cv2.waitKey(5) & 0xFF == ord('q'):
                fps.stop()
                break

            fps.update()
        except KeyboardInterrupt:
            fps.stop()
            break

    print("Elapsed time: " + str(fps.elapsed()))
    print("Approx FPS: " + str(fps.fps()))

    cv2.destroyAllWindows()
    vs.stop()
    time.sleep(2)


if __name__ == '__main__':
    main()