Commit d824511

added inference script

File tree

tflite_object_detection_with_video.py
tflite_object_detection_with_webcam.py

2 files changed: +296 -0 lines changed

tflite_object_detection_with_video.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# based on https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/detect_picamera.py
from tflite_runtime.interpreter import Interpreter, load_delegate
import argparse
import time
import cv2
import re
from PIL import Image, ImageDraw, ImageFont
import numpy as np


def draw_image(image, results, labels, size):
    for obj in results:
        print(obj)
        # Prepare image for drawing
        draw = ImageDraw.Draw(image)

        # Scale the normalized bounding box to pixel coordinates
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * size[0])
        xmax = int(xmax * size[0])
        ymin = int(ymin * size[1])
        ymax = int(ymax * size[1])

        # Draw rectangle to desired thickness; PIL expects (x0, y0, x1, y1),
        # so the coordinates are ordered x-first and offset per pass
        for offset in range(0, 4):
            draw.rectangle((xmin - offset, ymin - offset, xmax + offset, ymax + offset),
                           outline=(255, 255, 0))

        # Annotate image with label and confidence score
        display_str = labels[int(obj['class_id'])] + ": " + str(round(obj['score'] * 100, 2)) + "%"
        draw.text((xmin, ymin), display_str,
                  font=ImageFont.truetype("/usr/share/fonts/truetype/piboto/Piboto-Regular.ttf", 20))

    # The frame was converted to RGB for inference; convert back to BGR for OpenCV display
    display_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.imshow('Coral Live Object Detection', display_image)


def load_labels(path):
    """Loads the labels file. Supports files with or without index numbers."""
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    labels = {}
    for row_number, content in enumerate(lines):
        pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
        if len(pair) == 2 and pair[0].strip().isdigit():
            labels[int(pair[0])] = pair[1].strip()
        else:
            labels[row_number] = pair[0].strip()
    return labels


def set_input_tensor(interpreter, image):
    """Copies the image into the interpreter's input tensor (batch slot 0)."""
    tensor_index = interpreter.get_input_details()[0]['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image


def get_output_tensor(interpreter, index):
    """Returns the output tensor at the given index."""
    output_details = interpreter.get_output_details()[index]
    tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
    return tensor


def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # SSD post-processing outputs: boxes, classes, scores, detection count
    boxes = get_output_tensor(interpreter, 0)
    classes = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': classes[i],
                'score': scores[i]
            }
            results.append(result)
    return results


def make_interpreter(model_file, use_edgetpu):
    # An optional '@<device>' suffix on the model path selects a specific Edge TPU
    model_file, *device = model_file.split('@')
    if use_edgetpu:
        return Interpreter(
            model_path=model_file,
            experimental_delegates=[
                load_delegate('libedgetpu.so.1',
                              {'device': device[0]} if device else {})
            ]
        )
    else:
        return Interpreter(model_path=model_file)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', type=str, required=True, help='File path of .tflite file.')
    parser.add_argument('-l', '--labels', type=str, required=True, help='File path of labels file.')
    parser.add_argument('-t', '--threshold', type=float, default=0.4, required=False, help='Score threshold for detected objects.')
    parser.add_argument('-v', '--video', type=str, required=True, help='Path to video')
    parser.add_argument('-e', '--use_edgetpu', action='store_true', default=False, help='Use EdgeTPU')
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = make_interpreter(args.model, args.use_edgetpu)
    interpreter.allocate_tensors()
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

    # Initialize video stream
    video = cv2.VideoCapture(args.video)
    time.sleep(1)

    while video.isOpened():
        try:
            ret, frame = video.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_resized = cv2.resize(frame_rgb, (input_width, input_height))
            # Perform inference (set_input_tensor fills the batch-0 slot,
            # so the resized HxWx3 frame is passed without a batch dimension)
            results = detect_objects(interpreter, frame_resized, args.threshold)

            # Draw detections on the full-size frame
            image = Image.fromarray(frame_rgb)
            draw_image(image, results, labels, image.size)

            if cv2.waitKey(5) & 0xFF == ord('q'):
                break
        except KeyboardInterrupt:
            break

    cv2.destroyAllWindows()
    video.release()
    time.sleep(2)


if __name__ == '__main__':
    main()
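
Usage sketch (hedged): the model, label, and video filenames below are illustrative placeholders, not files shipped in this commit. Per the argparse flags above, the video script would be invoked as, e.g.,

    python3 tflite_object_detection_with_video.py -m detect_edgetpu.tflite -l coco_labels.txt -v sample.mp4 -t 0.4 -e

With -e/--use_edgetpu, make_interpreter also accepts an optional @<device> suffix on the model path (e.g. detect_edgetpu.tflite@usb:0, assuming libedgetpu's device-spec syntax) and forwards it to load_delegate to pin a specific Edge TPU.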

tflite_object_detection_with_webcam.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# based on https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/detect_picamera.py
from imutils.video import VideoStream, FPS
from tflite_runtime.interpreter import Interpreter, load_delegate
import argparse
import time
import cv2
import re
from PIL import Image, ImageDraw, ImageFont
import numpy as np


def draw_image(image, results, labels, size):
    for obj in results:
        print(obj)
        # Prepare image for drawing
        draw = ImageDraw.Draw(image)

        # Scale the normalized bounding box to pixel coordinates
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * size[0])
        xmax = int(xmax * size[0])
        ymin = int(ymin * size[1])
        ymax = int(ymax * size[1])

        # Draw rectangle to desired thickness; PIL expects (x0, y0, x1, y1),
        # so the coordinates are ordered x-first and offset per pass
        for offset in range(0, 4):
            draw.rectangle((xmin - offset, ymin - offset, xmax + offset, ymax + offset),
                           outline=(255, 255, 0))

        # Annotate image with label and confidence score
        display_str = labels[int(obj['class_id'])] + ": " + str(round(obj['score'] * 100, 2)) + "%"
        draw.text((xmin, ymin), display_str,
                  font=ImageFont.truetype("/usr/share/fonts/truetype/piboto/Piboto-Regular.ttf", 20))

    # The frame was converted to RGB for inference; convert back to BGR for OpenCV display
    display_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.imshow('Coral Live Object Detection', display_image)


def load_labels(path):
    """Loads the labels file. Supports files with or without index numbers."""
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    labels = {}
    for row_number, content in enumerate(lines):
        pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
        if len(pair) == 2 and pair[0].strip().isdigit():
            labels[int(pair[0])] = pair[1].strip()
        else:
            labels[row_number] = pair[0].strip()
    return labels


def set_input_tensor(interpreter, image):
    """Copies the image into the interpreter's input tensor (batch slot 0)."""
    tensor_index = interpreter.get_input_details()[0]['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image


def get_output_tensor(interpreter, index):
    """Returns the output tensor at the given index."""
    output_details = interpreter.get_output_details()[index]
    tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
    return tensor


def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # SSD post-processing outputs: boxes, classes, scores, detection count
    boxes = get_output_tensor(interpreter, 0)
    classes = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': classes[i],
                'score': scores[i]
            }
            results.append(result)
    return results


def make_interpreter(model_file, use_edgetpu):
    # An optional '@<device>' suffix on the model path selects a specific Edge TPU
    model_file, *device = model_file.split('@')
    if use_edgetpu:
        return Interpreter(
            model_path=model_file,
            experimental_delegates=[
                load_delegate('libedgetpu.so.1',
                              {'device': device[0]} if device else {})
            ]
        )
    else:
        return Interpreter(model_path=model_file)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', type=str, required=True, help='File path of .tflite file.')
    parser.add_argument('-l', '--labels', type=str, required=True, help='File path of labels file.')
    parser.add_argument('-t', '--threshold', type=float, default=0.4, required=False, help='Score threshold for detected objects.')
    parser.add_argument('-p', '--picamera', action='store_true', default=False, help='Use PiCamera for image capture')
    parser.add_argument('-e', '--use_edgetpu', action='store_true', default=False, help='Use EdgeTPU')
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = make_interpreter(args.model, args.use_edgetpu)
    interpreter.allocate_tensors()
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

    # Initialize video stream
    vs = VideoStream(usePiCamera=args.picamera, resolution=(640, 480)).start()
    time.sleep(1)

    fps = FPS().start()

    while True:
        try:
            # Read frame from video (camera frames are BGR; the model expects RGB)
            screenshot = vs.read()
            image = Image.fromarray(cv2.cvtColor(screenshot, cv2.COLOR_BGR2RGB))
            # Image.ANTIALIAS is named Image.LANCZOS on Pillow >= 10
            image_pred = image.resize((input_width, input_height), Image.ANTIALIAS)

            # Perform inference
            results = detect_objects(interpreter, image_pred, args.threshold)

            draw_image(image, results, labels, image.size)

            if cv2.waitKey(5) & 0xFF == ord('q'):
                fps.stop()
                break

            fps.update()
        except KeyboardInterrupt:
            fps.stop()
            break

    print("Elapsed time: " + str(fps.elapsed()))
    print("Approx FPS: " + str(fps.fps()))

    cv2.destroyAllWindows()
    vs.stop()
    time.sleep(2)


if __name__ == '__main__':
    main()
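
Usage sketch (hedged, placeholder filenames again): per the argparse flags above, the webcam script would be run as, e.g.,

    python3 tflite_object_detection_with_webcam.py -m detect_edgetpu.tflite -l coco_labels.txt -p -e

Omitting -p/--picamera makes imutils' VideoStream fall back to the default webcam; pressing q in the display window stops the loop, after which the script prints the elapsed time and approximate FPS.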
