Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 5bf8c1c

Browse files
Add Image Classifier
1 parent f5f7ce5 commit 5bf8c1c

File tree

4 files changed

+302
-2
lines changed

4 files changed

+302
-2
lines changed

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
### Find this project useful ? :heart:
1010
* Support it by clicking the :star: button on the upper right of this page. :v:
1111

12-
##[Check out Mindorks awesome open source projects here](https://mindorks.com/open-source-projects)
12+
[Check out Mindorks awesome open source projects here](https://mindorks.com/open-source-projects)
1313

1414
### License
1515
```

‎app/build.gradle‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ android {
3636
}
3737

3838
dependencies {
39-
compile fileTree(dir: 'libs', include: ['*.jar'])
39+
compile fileTree(include: ['*.jar'], dir: 'libs')
4040
androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
4141
exclude group: 'com.android.support', module: 'support-annotations'
4242
})
4343
compile 'com.android.support:appcompat-v7:25.2.0'
4444
testCompile 'junit:junit:4.12'
45+
compile files('libs/libandroid_tensorflow_inference_java.jar')
4546
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (C) 2017 MINDORKS NEXTGEN PRIVATE LIMITED
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mindorks.tensorflowexample;
18+
19+
import android.graphics.RectF;
20+
21+
import java.util.List;
22+
23+
/**
24+
* Created by amitshekhar on 16/03/17.
25+
*/
26+
27+
/**
28+
* Generic interface for interacting with different recognition engines.
29+
*/
30+
public interface Classifier {
31+
/**
32+
* An immutable result returned by a Classifier describing what was recognized.
33+
*/
34+
public class Recognition {
35+
/**
36+
* A unique identifier for what has been recognized. Specific to the class, not the instance of
37+
* the object.
38+
*/
39+
private final String id;
40+
41+
/**
42+
* Display name for the recognition.
43+
*/
44+
private final String title;
45+
46+
/**
47+
* A sortable score for how good the recognition is relative to others. Higher should be better.
48+
*/
49+
private final Float confidence;
50+
51+
/**
52+
* Optional location within the source image for the location of the recognized object.
53+
*/
54+
private RectF location;
55+
56+
public Recognition(
57+
final String id, final String title, final Float confidence, final RectF location) {
58+
this.id = id;
59+
this.title = title;
60+
this.confidence = confidence;
61+
this.location = location;
62+
}
63+
64+
public String getId() {
65+
return id;
66+
}
67+
68+
public String getTitle() {
69+
return title;
70+
}
71+
72+
public Float getConfidence() {
73+
return confidence;
74+
}
75+
76+
public RectF getLocation() {
77+
return new RectF(location);
78+
}
79+
80+
public void setLocation(RectF location) {
81+
this.location = location;
82+
}
83+
84+
@Override
85+
public String toString() {
86+
String resultString = "";
87+
if (id != null) {
88+
resultString += "[" + id + "] ";
89+
}
90+
91+
if (title != null) {
92+
resultString += title + " ";
93+
}
94+
95+
if (confidence != null) {
96+
resultString += String.format("(%.1f%%) ", confidence * 100.0f);
97+
}
98+
99+
if (location != null) {
100+
resultString += location + " ";
101+
}
102+
103+
return resultString.trim();
104+
}
105+
}
106+
107+
List<Recognition> recognizeImage(float[] pixels);
108+
109+
void enableStatLogging(final boolean debug);
110+
111+
String getStatString();
112+
113+
void close();
114+
}
115+
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
* Copyright (C) 2017 MINDORKS NEXTGEN PRIVATE LIMITED
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mindorks.tensorflowexample;
18+
19+
import android.content.res.AssetManager;
20+
import android.os.Trace;
21+
import android.util.Log;
22+
23+
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
24+
25+
import java.io.BufferedReader;
26+
import java.io.IOException;
27+
import java.io.InputStreamReader;
28+
import java.util.ArrayList;
29+
import java.util.Comparator;
30+
import java.util.List;
31+
import java.util.PriorityQueue;
32+
import java.util.Vector;
33+
34+
/**
35+
* Created by amitshekhar on 16/03/17.
36+
*/
37+
38+
/**
39+
* A classifier specialized to label images using TensorFlow.
40+
*/
41+
public class TensorFlowImageClassifier implements Classifier {
42+
43+
private static final String TAG = "TensorFlowImageClassifier";
44+
45+
// Only return this many results with at least this confidence.
46+
private static final int MAX_RESULTS = 3;
47+
private static final float THRESHOLD = 0.1f;
48+
49+
// Config values.
50+
private String inputName;
51+
private String outputName;
52+
private int inputSize;
53+
54+
// Pre-allocated buffers.
55+
private Vector<String> labels = new Vector<String>();
56+
private float[] outputs;
57+
private String[] outputNames;
58+
59+
private TensorFlowInferenceInterface inferenceInterface;
60+
61+
private TensorFlowImageClassifier() {
62+
}
63+
64+
/**
65+
* Initializes a native TensorFlow session for classifying images.
66+
*
67+
* @param assetManager The asset manager to be used to load assets.
68+
* @param modelFilename The filepath of the model GraphDef protocol buffer.
69+
* @param labelFilename The filepath of label file for classes.
70+
* @param inputSize The input size. A square image of inputSize x inputSize is assumed.
71+
* @param inputName The label of the image input node.
72+
* @param outputName The label of the output node.
73+
* @throws IOException
74+
*/
75+
public static Classifier create(
76+
AssetManager assetManager,
77+
String modelFilename,
78+
String labelFilename,
79+
int inputSize,
80+
String inputName,
81+
String outputName)
82+
throws IOException {
83+
TensorFlowImageClassifier c = new TensorFlowImageClassifier();
84+
c.inputName = inputName;
85+
c.outputName = outputName;
86+
87+
// Read the label names into memory.
88+
// TODO(andrewharp): make this handle non-assets.
89+
String actualFilename = labelFilename.split("file:///android_asset/")[1];
90+
Log.i(TAG, "Reading labels from: " + actualFilename);
91+
BufferedReader br = null;
92+
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
93+
String line;
94+
while ((line = br.readLine()) != null) {
95+
c.labels.add(line);
96+
}
97+
br.close();
98+
99+
c.inferenceInterface = new TensorFlowInferenceInterface();
100+
if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
101+
throw new RuntimeException("TF initialization failed");
102+
}
103+
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
104+
int numClasses =
105+
(int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
106+
Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
107+
108+
// Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
109+
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
110+
// must be passed in as a parameter.
111+
c.inputSize = inputSize;
112+
113+
// Pre-allocate buffers.
114+
c.outputNames = new String[]{outputName};
115+
c.outputs = new float[numClasses];
116+
117+
return c;
118+
}
119+
120+
@Override
121+
public List<Recognition> recognizeImage(final float[] pixels) {
122+
// Log this method so that it can be analyzed with systrace.
123+
Trace.beginSection("recognizeImage");
124+
125+
// Copy the input data into TensorFlow.
126+
Trace.beginSection("fillNodeFloat");
127+
inferenceInterface.fillNodeFloat(
128+
inputName, new int[]{inputSize * inputSize}, pixels);
129+
Trace.endSection();
130+
131+
// Run the inference call.
132+
Trace.beginSection("runInference");
133+
inferenceInterface.runInference(outputNames);
134+
Trace.endSection();
135+
136+
// Copy the output Tensor back into the output array.
137+
Trace.beginSection("readNodeFloat");
138+
inferenceInterface.readNodeFloat(outputName, outputs);
139+
Trace.endSection();
140+
141+
// Find the best classifications.
142+
PriorityQueue<Recognition> pq =
143+
new PriorityQueue<Recognition>(
144+
3,
145+
new Comparator<Recognition>() {
146+
@Override
147+
public int compare(Recognition lhs, Recognition rhs) {
148+
// Intentionally reversed to put high confidence at the head of the queue.
149+
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
150+
}
151+
});
152+
for (int i = 0; i < outputs.length; ++i) {
153+
if (outputs[i] > THRESHOLD) {
154+
pq.add(
155+
new Recognition(
156+
"" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
157+
}
158+
}
159+
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
160+
int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
161+
for (int i = 0; i < recognitionsSize; ++i) {
162+
recognitions.add(pq.poll());
163+
}
164+
Trace.endSection(); // "recognizeImage"
165+
return recognitions;
166+
}
167+
168+
@Override
169+
public void enableStatLogging(boolean debug) {
170+
inferenceInterface.enableStatLogging(debug);
171+
}
172+
173+
@Override
174+
public String getStatString() {
175+
return inferenceInterface.getStatString();
176+
}
177+
178+
@Override
179+
public void close() {
180+
inferenceInterface.close();
181+
}
182+
}
183+
184+

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /