Skip to content

Commit e81ffaa

Browse files
committed
big refactoring, added ImageClassifier class
1 parent 13b9efd commit e81ffaa

File tree

12 files changed

+908
-671
lines changed

12 files changed

+908
-671
lines changed

example-basic/src/example-basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ofApp : public ofBaseApp{
1818
public:
1919

2020
// main interface to everything tensorflow
21-
ofxMSATensorFlow msa_tf;
21+
msa::tf::ofxMSATensorFlow msa_tf;
2222

2323
// input tensors
2424
tensorflow::Tensor a, b;

example-inception3/src/example-inception3.cpp

Lines changed: 91 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
/*
2-
* Image recognition using Google's Inception network
2+
* Image recognition using Google's Inception v3 network
33
* based on https://www.tensorflow.org/versions/master/tutorials/image_recognition/index.html
44
*
5-
*
65
* Uses pre-trained model https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
76
*
87
* openFrameworks code loads and processes pre-trained model (i.e. makes calculations/predictions)
@@ -15,142 +14,38 @@
1514
#include "ofxMSATensorFlow.h"
1615

1716

18-
// input image dimensions dictated by trained model
19-
#define kInputWidth 299
20-
#define kInputHeight 299
21-
#define kInputSize (kInputWidth * kInputHeight)
22-
23-
24-
// we need to normalize the images before feeding into the network
25-
// from each pixel we subtract the mean and divide by variance
26-
// this is also dictated by the trained model
27-
#define kInputMean (128.0f/255.0f)
28-
#define kInputStd (128.0f/255.0f)
29-
30-
// model & labels files to load
31-
#define kModelPath "models/tensorflow_inception_graph.pb"
32-
#define kLabelsPath "models/imagenet_comp_graph_label_strings.txt"
33-
34-
35-
// every node in the network has a name
36-
// when passing in data to the network, or reading data back, we need to refer to the node by name
37-
// i.e. 'pass this data to node A', or 'read data back from node X'
38-
// these node names are specific to the architecture of the model
39-
#define kInputLayer "Mul"
40-
#define kOutputLayer "softmax"
41-
42-
43-
4417
//--------------------------------------------------------------
4518
// ofImage::load() (ie. Freeimage load) doesn't work with TensorFlow! (See README.md)
4619
// so I have to resort to this awful trick of loading raw image data 299x299 RGB
47-
static void loadImageRaw(string path, ofImage &img) {
20+
static void loadImageRaw(string path, ofImage &img, int w, int h) {
4821
ofFile file(path);
49-
img.setFromPixels((unsigned char*)file.readToBuffer().getData(), kInputWidth, kInputHeight, OF_IMAGE_COLOR);
50-
}
51-
52-
53-
54-
//--------------------------------------------------------------
55-
// Takes a file name, and loads a list of labels from it, one per line, and
56-
// returns a vector of the strings. It pads with empty strings so the length
57-
// of the result is a multiple of 16, because our model expects that.
58-
static bool ReadLabelsFile(string file_name, std::vector<string>* result) {
59-
std::ifstream file(file_name);
60-
if (!file) {
61-
ofLogError() <<"ReadLabelsFile: " << file_name << " not found.";
62-
return false;
63-
}
64-
65-
result->clear();
66-
string line;
67-
while (std::getline(file, line)) {
68-
result->push_back(line);
69-
}
70-
const int padding = 16;
71-
while (result->size() % padding) {
72-
result->emplace_back();
73-
}
74-
return true;
22+
img.setFromPixels((unsigned char*)file.readToBuffer().getData(), w, h, OF_IMAGE_COLOR);
7523
}
7624

7725

7826

79-
class ofApp : public ofBaseApp{
27+
class ofApp : public ofBaseApp {
8028
public:
8129

82-
// main interface to everything tensorflow
83-
ofxMSATensorFlow msa_tf;
84-
85-
// Tensor to hold input image which is fed into the network
86-
tensorflow::Tensor image_tensor;
87-
88-
// vector of Tensors to hold data coming back from the network
89-
// (it's a vector of Tensors, because that's how the API works)
90-
vector<tensorflow::Tensor> output_tensors;
30+
// classifies pixels
31+
// check the src of this class (ofxMSATFImageClassifier) to see how to do more generic stuff with ofxMSATensorFlow
32+
msa::tf::ImageClassifier classifier;
9133

9234
// for webcam input
9335
shared_ptr<ofVideoGrabber> video_grabber;
9436

95-
// contains input image to classify
96-
ofImage input_image;
97-
98-
// normalized float version of input image
99-
// keeping texture separate so it's not unnessecarily updated when it isn't needed
100-
ofFloatPixels processed_pix;
101-
ofTexture processed_tex;
102-
103-
// contains all labels
104-
vector<string> labels;
105-
10637
// folder of images to classify
10738
ofDirectory image_dir;
10839

109-
// contains classification information from last classification attempt
110-
vector<int> top_label_indices;
111-
vector<float> top_scores;
112-
113-
//---------------------------------------------------------
114-
// Load pixels into the network, get the results
115-
void classify(ofPixels &pix) {
116-
// convert from unsigned char pix to float pix
117-
processed_pix = pix;
118-
119-
// need to resize image to specific dimensions the model is expecting
120-
processed_pix.resize(kInputWidth, kInputHeight);
121-
122-
// pixelwise normalize image by subtracting the mean and dividing by variance (across entire dataset)
123-
// I could do this without iterating over the pixels, by setting up a TensorFlow Graph, but I can't be bothered, this is less code
124-
float* pix_data = processed_pix.getData();
125-
if(!pix_data) {
126-
ofLogError() << "Could not classify. pixel data is NULL";
127-
return;
128-
}
129-
for(int i=0; i<kInputSize*3; i++) pix_data[i] = (pix_data[i] - kInputMean) / kInputStd;
130-
131-
// make sure opengl texture is updated with new pixel info (needed for correct rendering)
132-
processed_tex.loadData(processed_pix);
133-
134-
// copy data from image into tensorflow's Tensor class
135-
ofxMSATensorFlow::pixelsToTensor(processed_pix, image_tensor);
136-
137-
// feed the data into the network, and request output
138-
// output_tensors don't need to be initialized or allocated. they will be filled once the network runs
139-
if( !msa_tf.run({ {kInputLayer, image_tensor } }, { kOutputLayer }, {}, &output_tensors) ) {
140-
ofLogError() << "Error during running. Check console for details." << endl;
141-
return;
142-
}
143-
144-
// the output from the network above is an array of probabilities for every single label
145-
// i.e. thousands of probabilities, we only want to the top few
146-
ofxMSATensorFlow::getTopScores(output_tensors[0], 6, top_label_indices, top_scores);
147-
}
148-
40+
// top scoring classes
41+
vector<int> top_label_indices; // contains top n label indices for input image
42+
vector<float> top_class_probs; // contains top n probabilities for current input image
14943

15044

15145
//--------------------------------------------------------------
15246
void loadNextImage() {
15347
static int file_index = 0;
48+
ofImage img;
15449

15550
// System load dialog doesn't work with tensorflow :(
15651
//auto o = ofSystemLoadDialog("Select image");
@@ -160,35 +55,48 @@ class ofApp : public ofBaseApp{
16055
//img.load("images/fanboy.jpg");
16156

16257
// resorting to awful raw data file load hack!
163-
loadImageRaw(image_dir.getPath(file_index), input_image);
164-
classify(input_image.getPixels());
58+
loadImageRaw(image_dir.getPath(file_index), img, 299, 299);
59+
60+
classify(img.getPixels());
16561
file_index = (file_index+1) % image_dir.getFiles().size();
16662
}
16763

16864

65+
//--------------------------------------------------------------
66+
void classify(const ofPixels& pix) {
67+
// classify pixels
68+
classifier.classify(pix);
69+
70+
msa::tf::getTopScores(classifier.getOutputTensors()[0], 6, top_label_indices, top_class_probs);
71+
}
72+
16973
//--------------------------------------------------------------
17074
void setup(){
17175
ofLogNotice() << "Initializing... ";
17276
ofBackground(0);
17377
ofSetVerticalSync(true);
174-
ofSetFrameRate(60);
78+
// ofSetFrameRate(60);
79+
80+
// initialize the image classifier, lots of params to setup
81+
// these settings are specific to the model
82+
msa::tf::ImageClassifier::Settings settings;
83+
settings.image_dims = { 299, 299, 3 };
84+
settings.itensor_dims = { 1, 299, 299, 3 };
85+
settings.model_path = "models/tensorflow_inception_graph.pb";
86+
settings.labels_path = "models/imagenet_comp_graph_label_strings.txt";
87+
settings.input_layer_name = "Mul";
88+
settings.output_layer_name = "softmax";
89+
settings.dropout_layer_name = "";
90+
settings.varconst_layer_suffix = "_VARHACK";
91+
settings.norm_mean = 128.0f/255.0f;
92+
settings.norm_stddev = 128.0f/255.0f;
93+
94+
// initialize classifier with these settings
95+
classifier.setup(settings);
17596

17697
// get a list of all images in the 'images' folder
17798
image_dir.listDir("images");
17899

179-
// Initialize tensorflow session, return if error
180-
if( !msa_tf.setup() ) return;
181-
182-
// Load graph (i.e. trained model) add to session, return if error
183-
if( !msa_tf.loadGraph(kModelPath) ) return;
184-
185-
// load text file containing labels (i.e. associating classification index with human readable text)
186-
if( !ReadLabelsFile(ofToDataPath(kLabelsPath), &labels) ) return;
187-
188-
// initialize input tensor dimensions
189-
// (not sure what the best way to do this was as there isn't an 'init' method, just a constructor)
190-
image_tensor = tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, kInputHeight, kInputWidth, 3 }));
191-
192100
// load first image to classify
193101
loadNextImage();
194102

@@ -197,66 +105,69 @@ class ofApp : public ofBaseApp{
197105

198106

199107
//--------------------------------------------------------------
200-
void update(){
201-
108+
void update() {
202109
// if video_grabber active,
203110
if(video_grabber) {
204111
// grab frame
205112
video_grabber->update();
206113

207114
if(video_grabber->isFrameNew()) {
208-
209-
// update input_image so it's drawn in the right place
210-
input_image.setFromPixels(video_grabber->getPixels());
211-
212115
// send to classification if keypressed
213-
if(ofGetKeyPressed(' ')) classify(input_image.getPixels());
116+
if(ofGetKeyPressed(' '))
117+
classify(video_grabber->getPixels());
214118
}
215119
}
216120
}
217121

122+
218123
//--------------------------------------------------------------
219-
void draw(){
220-
// draw input image if it's available
221-
float x = 0;
222-
if(input_image.isAllocated()) {
223-
input_image.draw(x, 0);
224-
x += input_image.getWidth();
225-
}
124+
void draw() {
125+
if(classifier.isReady()) {
126+
ofSetColor(255);
226127

227-
// draw processed image if it's available
228-
if(processed_tex.isAllocated()) {
229-
processed_tex.draw(x, 0);
230-
x += processed_tex.getWidth();
231-
}
128+
// if video grabber active, draw in bottom left corner
129+
if(video_grabber) video_grabber->draw(0, ofGetHeight() - 240, 320, 240);
232130

233-
x += 20;
234-
float w = ofGetWidth() - 400 - x;
235-
float y = 40;
236-
float bar_height = 35;
237131

238-
// iterate top scores and draw them
239-
for(int i=0; i<top_scores.size(); i++) {
240-
int label_index = top_label_indices[i];
241-
string label = labels[label_index];
242-
float p = top_scores[i]; // the score (i.e. probability, 0...1)
132+
float x = 0;
243133

244-
// draw full bar
245-
ofSetColor(ofLerp(50.0, 255.0, p), ofLerp(100.0, 0.0, p), ofLerp(150.0, 0.0, p));
246-
ofDrawRectangle(x, y, w * p, bar_height);
247-
ofSetColor(40);
134+
// draw input image
135+
classifier.getInputImage().draw(x, 0);
136+
x += classifier.getInputImage().getWidth();
248137

249-
// draw outline
250-
ofNoFill();
251-
ofDrawRectangle(x, y, w, bar_height);
252-
ofFill();
138+
// draw processed image
139+
classifier.getProcessedImage().draw(x, 0);
140+
x += classifier.getProcessedImage().getWidth();
253141

254-
// draw text
255-
ofSetColor(255);
256-
ofDrawBitmapString(label + " (" + ofToString(label_index) + "): " + ofToString(p,4), x + w + 10, y + 20);
257-
y += bar_height + 5;
258-
}
142+
x += 20;
143+
144+
float w = ofGetWidth() - 400 - x;
145+
float y = 40;
146+
float bar_height = 35;
147+
148+
149+
// iterate top scores and draw them
150+
for(int i=0; i<top_class_probs.size(); i++) {
151+
int label_index = top_label_indices[i];
152+
string label = classifier.getLabels()[label_index];
153+
float p = top_class_probs[i]; // the score (i.e. probability, 0...1)
259154

155+
// draw full bar
156+
ofSetColor(ofLerp(50.0, 255.0, p), ofLerp(100.0, 0.0, p), ofLerp(150.0, 0.0, p));
157+
ofDrawRectangle(x, y, w * p, bar_height);
158+
ofSetColor(40);
159+
160+
// draw outline
161+
ofNoFill();
162+
ofDrawRectangle(x, y, w, bar_height);
163+
ofFill();
164+
165+
// draw text
166+
ofSetColor(255);
167+
ofDrawBitmapString(label + " (" + ofToString(label_index) + "): " + ofToString(p,4), x + w + 10, y + 20);
168+
y += bar_height + 5;
169+
}
170+
}
260171

261172
ofSetColor(255);
262173
ofDrawBitmapString(ofToString(ofGetFrameRate()), ofGetWidth() - 100, 30);
@@ -265,7 +176,7 @@ class ofApp : public ofBaseApp{
265176
str_inst << "'l' to load image\n";
266177
str_inst << "or drag an image (must be raw, 299x299) onto the window\n";
267178
str_inst << "'v' to toggle video input";
268-
ofDrawBitmapString(str_inst.str(), 15, input_image.getHeight() + 30);
179+
ofDrawBitmapString(str_inst.str(), 15, classifier.getHeight() + 30);
269180
}
270181

271182

@@ -291,14 +202,18 @@ class ofApp : public ofBaseApp{
291202
void dragEvent(ofDragInfo dragInfo){
292203
if(dragInfo.files.empty()) return;
293204

205+
ofImage img;
206+
294207
string filePath = dragInfo.files[0];
295208
//img.load(filePath); // FreeImage doesn't work :(
296-
loadImageRaw(filePath, input_image);
297-
classify(input_image.getPixels());
209+
loadImageRaw(filePath, img, 299, 299);
210+
classify(img.getPixels());
298211
}
299212

300213
};
301214

215+
216+
302217
//========================================================================
303218
int main( ){
304219
ofSetupOpenGL(1200, 800, OF_WINDOW); // <-------- setup the GL context

0 commit comments

Comments
 (0)