Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 33ec3a5

Browse files
authored
Add files via upload
Added two scripts: (1) a batch test script and (2) a class-imbalance check script.
1 parent 09d2742 commit 33ec3a5

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
from PIL import Image
4+
import matplotlib
5+
import os
6+
7+
# Force matplotlib to use 'Agg' backend (non-interactive)
8+
matplotlib.use('Agg')
9+
import matplotlib.pyplot as plt
10+
11+
def load_image(image_path):
    """Load an image file and return it as an RGB NumPy array.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        The result of ``np.array`` applied to the RGB-converted image, or
        ``None`` when the file cannot be opened or converted (the error is
        printed instead of raised so one bad file does not stop a batch).
    """
    try:
        # Use the image as a context manager so the underlying file handle is
        # released deterministically; convert() returns an independent copy,
        # so the array stays valid after the source image is closed.
        with Image.open(image_path) as img:
            return np.array(img.convert('RGB'))
    except Exception as e:
        # Deliberately broad: loading is best-effort and callers only check
        # for None.
        print(f"Image load error: {e}")
        return None
18+
19+
def run_inference(image_np, detection_graph):
    """Run the detection graph on a single image.

    A new session is opened on ``detection_graph`` for this call. Only the
    output tensors that actually exist in the graph are fetched.

    Args:
        image_np: Image array; a leading batch axis is added before it is fed
            to the ``image_tensor:0`` input.
        detection_graph: Graph object holding the detection model.

    Returns:
        A tuple ``(boxes, scores, classes)`` taken from the first batch entry
        of the ``detection_boxes``, ``detection_scores`` and
        ``detection_classes`` outputs; classes are cast to ``np.int32``.
    """
    wanted_keys = ('num_detections', 'detection_boxes',
                   'detection_scores', 'detection_classes')

    with detection_graph.as_default(), \
            tf.compat.v1.Session(graph=detection_graph) as sess:
        # Collect the name of every output tensor present in the graph.
        available = set()
        for op in detection_graph.get_operations():
            available.update(output.name for output in op.outputs)

        # Fetch only the outputs this particular graph provides.
        fetches = {
            key: detection_graph.get_tensor_by_name(key + ':0')
            for key in wanted_keys
            if key + ':0' in available
        }

        batch = np.expand_dims(image_np, axis=0)
        results = sess.run(fetches, feed_dict={'image_tensor:0': batch})

    # Strip the batch axis from each result.
    boxes = results['detection_boxes'][0]
    scores = results['detection_scores'][0]
    classes = results['detection_classes'][0].astype(np.int32)
    return (boxes, scores, classes)
37+
38+
def save_visualization(image_np, boxes, scores, classes, label_map, output_path,
                       score_threshold=0.5, max_boxes=20):
    """Draw detections on an image and save the figure to a file.

    Args:
        image_np: Image array; its first two dimensions are used as
            ``(height, width)`` to scale the boxes.
        boxes: Sequence of ``(y1, x1, y2, x2)`` boxes. Coordinates are
            multiplied by the image height/width, so they are expected to be
            fractions of the image size.
        scores: Sequence of confidence scores, parallel to ``boxes``.
        classes: Sequence of class ids, parallel to ``boxes``.
        label_map: Mapping from class id to display name; ids missing from
            the mapping are shown as their string form.
        output_path: Destination file for the rendered figure.
        score_threshold: Only detections scoring strictly above this value
            are drawn. Defaults to 0.5.
        max_boxes: Only the first ``max_boxes`` detections are considered.
            Defaults to 20.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.imshow(image_np)

        height, width = image_np.shape[:2]

        candidates = zip(boxes[:max_boxes], scores[:max_boxes], classes[:max_boxes])
        for box, score, class_id in candidates:
            if score <= score_threshold:
                continue

            # Scale the box to pixel coordinates.
            y1, x1, y2, x2 = box
            y1, x1, y2, x2 = int(y1*height), int(x1*width), int(y2*height), int(x2*width)

            rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                                 fill=False, color='red', linewidth=2)
            ax.add_patch(rect)

            label = f"{label_map.get(class_id, str(class_id))}: {score:.2f}"
            # Place the caption just above the top-left corner of the box.
            ax.text(x1, y1-10, label, color='red', fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.7))

        ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails, so a
        # long batch run does not accumulate open figures.
        plt.close(fig)
    print(f"Saved results to {output_path}")
64+
65+
def process_single_image(image_path, output_dir, detection_graph, label_map):
    """Detect objects in one image and write the annotated result.

    The output file is placed in ``output_dir`` and named after the input
    file with a ``detected_`` prefix. Nothing is written when the image
    cannot be loaded.

    Args:
        image_path: Path of the image to process.
        output_dir: Directory that receives the annotated image.
        detection_graph: Graph passed through to ``run_inference``.
        label_map: Class-id-to-name mapping passed to ``save_visualization``.
    """
    pixels = load_image(image_path)
    if pixels is not None:
        # Derive the destination from the input file name.
        source_name = os.path.basename(image_path)
        destination = os.path.join(output_dir, f"detected_{source_name}")

        # Detect, then render the detections onto the image.
        det_boxes, det_scores, det_classes = run_inference(pixels, detection_graph)
        save_visualization(pixels, det_boxes, det_scores, det_classes,
                           label_map, destination)
80+
81+
def main():
    """Run detection over every supported image in the input directory.

    Loads the frozen graph once, processes each file whose name ends in a
    supported image extension, and prints a summary. The reported count is
    the number of matching files handed to ``process_single_image``; files
    that fail to load are still included in it.
    """
    # Configuration
    model_path = 'learn_pet/models/saved_model_640_4963/frozen_inference_graph.pb'
    input_dir = 'learn_pet/pet/images'  # Directory containing images to process
    output_dir = 'learn_pet/eval'  # Where to save results
    label_map = {1: 'person', 2: 'car'}  # Update with your classes
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    # Make sure the destination exists before any work is done.
    os.makedirs(output_dir, exist_ok=True)

    # Load the frozen model a single time and reuse it for every image.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(model_path, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name='')

    # Pick the files to process (extension match is case-insensitive).
    image_names = [
        name for name in os.listdir(input_dir)
        if name.lower().endswith(image_suffixes)
    ]
    for name in image_names:
        process_single_image(os.path.join(input_dir, name), output_dir,
                             detection_graph, label_map)
    processed_count = len(image_names)

    print(f"\nProcessing complete. Processed {processed_count} images.")
    print(f"Results saved to: {os.path.abspath(output_dir)}")


if __name__ == "__main__":
    main()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import tensorflow as tf
2+
tf.compat.v1.enable_eager_execution() # Forces eager mode in TF 1.x
3+
from collections import defaultdict
4+
5+
# Schema used to parse each serialized example (should match how the
# TFRecords were created). Per-image fields are single fixed-length values;
# per-object fields are variable-length because the number of annotated
# objects differs between images.
feature_description = {
    **{
        key: tf.io.FixedLenFeature([], dtype)
        for key, dtype in (
            ('image/height', tf.int64),
            ('image/width', tf.int64),
            ('image/filename', tf.string),
            ('image/source_id', tf.string),
            ('image/encoded', tf.string),
            ('image/format', tf.string),
        )
    },
    **{
        key: tf.io.VarLenFeature(dtype)
        for key, dtype in (
            ('image/object/bbox/xmin', tf.float32),
            ('image/object/bbox/ymin', tf.float32),
            ('image/object/bbox/xmax', tf.float32),
            ('image/object/bbox/ymax', tf.float32),
            ('image/object/class/text', tf.string),  # Class names
            ('image/object/class/label', tf.int64),  # Class IDs
        )
    },
}
20+
21+
def parse_tfrecord(example_proto):
    """Parse one serialized example using ``feature_description``.

    Args:
        example_proto: A single serialized example read from a TFRecord file.

    Returns:
        Whatever ``tf.io.parse_single_example`` produces for the
        module-level ``feature_description`` schema.
    """
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    return parsed
23+
24+
# Number of annotated objects seen per class name.
class_counts = defaultdict(int)

# Path to your TFRecord file(s)
tfrecord_paths = ["/tensorflow/models/research/learn_pet/pet/train.record"]

# Read the raw records and parse each one with the schema above.
raw_dataset = tf.data.TFRecordDataset(tfrecord_paths)
parsed_dataset = raw_dataset.map(parse_tfrecord)

# Walk every record and tally the class name of each annotated object.
for record in parsed_dataset:
    encoded_names = record['image/object/class/text'].values.numpy()
    for encoded_name in encoded_names:
        # Class names are stored as bytes; decode them before counting.
        class_counts[encoded_name.decode('utf-8')] += 1

# Print results
print("Class distribution in TFRecords:")
for class_name in class_counts:
    print(f"{class_name}: {class_counts[class_name]} objects")

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /