package com.whyc.util;
|
|
import java.io.IOException;
|
import java.io.PrintStream;
|
import java.nio.charset.Charset;
|
import java.nio.file.Files;
|
import java.nio.file.Path;
|
import java.nio.file.Paths;
|
import java.util.Arrays;
|
import java.util.List;
|
|
import org.tensorflow.*;
|
import org.tensorflow.types.UInt8;
|
|
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
|
public class LabelImage {
|
private static void printUsage(PrintStream s) {
|
final String url =
|
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
|
s.println(
|
"Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
|
s.println("to label JPEG images.");
|
s.println("TensorFlow version: " + TensorFlow.version());
|
s.println();
|
s.println("Usage: label_image <model dir> <image file>");
|
s.println();
|
s.println("Where:");
|
s.println("<model dir> is a directory containing the unzipped contents of the inception model");
|
s.println(" (from " + url + ")");
|
s.println("<image file> is the path to a JPEG image file");
|
}
|
|
public static void main(String[] args) {
|
test2(args);
|
|
}
|
|
private static void test2(String[] args){
|
String exportDir = "F:\\tensorflow\\retrain\\saved_model\\";
|
String imageFile = "F:\\tensorflow\\retrain\\test_images\\b1.jpg";
|
String labelDir = "F:\\tensorflow\\retrain\\";
|
SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
|
|
List<String> labels =
|
readAllLinesOrExit(Paths.get(labelDir, "output_labels.txt"));
|
//readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
|
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
|
|
//对输入的图片进行处理
|
Tensor<Float> image = constructAndExecuteGraphToNormalizeImage2(imageBytes);
|
float[] labelProbabilities = executeInceptionGraph2(model, image);
|
/*int bestLabelIdx = maxIndex(labelProbabilities);
|
System.out.println(
|
String.format("BEST MATCH: %s (%.2f%% likely)",
|
labels.get(bestLabelIdx),
|
labelProbabilities[bestLabelIdx] * 100f));*/
|
for (int i = 0; i < labelProbabilities.length; i++) {
|
System.out.println(
|
String.format("BEST MATCH: %s (%.2f%% likely)",
|
labels.get(i),
|
labelProbabilities[i] * 100f));
|
}
|
}
|
|
/**
|
* 原案例
|
* @param args
|
*/
|
private static void test(String[] args){
|
if (args.length != 2) {
|
printUsage(System.err);
|
System.exit(1);
|
}
|
String modelDir = args[0];
|
String imageFile = args[1];
|
|
//byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
|
//byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "output_graph.pb"));
|
//byte[] graphDef = SavedModelBundle.load("F:\\tensorflow\\retrain\\saved_model3", "serve").metaGraphDef();
|
SavedModelBundle model = SavedModelBundle.load("F:\\tensorflow\\retrain\\saved_model3", "serve");
|
List<String> labels =
|
readAllLinesOrExit(Paths.get(modelDir, "output_labels.txt"));
|
//readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
|
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
|
|
try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
|
//try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(model,imageBytes)) {
|
float[] labelProbabilities = executeInceptionGraph(model.metaGraphDef(),image);
|
int bestLabelIdx = maxIndex(labelProbabilities);
|
System.out.println(
|
String.format("BEST MATCH: %s (%.2f%% likely)",
|
labels.get(bestLabelIdx),
|
labelProbabilities[bestLabelIdx] * 100f));
|
}
|
}
|
|
private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {//把图片转换成inception需要模式
|
try (Graph g = new Graph()) {//创建一个空的构造方法
|
GraphBuilder b = new GraphBuilder(g);
|
// Some constants specific to the pre-trained model at:
|
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
|
//
|
// - The model was trained with images scaled to 224x224 pixels.
|
// - The colors, represented as R, G, B in 1-byte each were converted to
|
// float using (value - Mean)/Scale.
|
final int H = 224;
|
final int W = 224;
|
final float mean = 117f;
|
final float scale = 1f;
|
|
// Since the graph is being constructed once per execution here, we can use a constant for the
|
// input image. If the graph were to be re-used for multiple input images, a placeholder would
|
// have been more appropriate.
|
final Output<String> input = b.constant("DecodeJpeg/contents", imageBytes);//DecodeJpeg/contents:0
|
final Output<Float> output =
|
b.div(
|
b.sub(
|
b.resizeBilinear(
|
b.expandDims(
|
b.cast(b.decodeJpeg(input, 3), Float.class),
|
b.constant("make_batch", 0)),
|
b.constant("size", new int[] {H, W})),
|
b.constant("mean", mean)),
|
b.constant("scale", scale));
|
try (Session s = new Session(g)) {
|
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
|
return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
|
}
|
}
|
}
|
|
private static Tensor<Float> constructAndExecuteGraphToNormalizeImage2(byte[] imageByte) {
|
Graph g = new Graph();
|
GraphBuilder b = new GraphBuilder(g);
|
// Some constants specific to the pre-trained model at:
|
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
|
//
|
// - The model was trained with images scaled to 224x224 pixels.
|
// - The colors, represented as R, G, B in 1-byte each were converted to
|
// float using (value - Mean)/Scale.
|
//final int H = 224;
|
final int H = 299;
|
final int W = 299;
|
//final int W = 224;
|
|
//需要实际调试
|
final float mean = 0.5f;
|
//final float mean = 127.5f;
|
//final float scale = 1f;
|
//final float scale = 127.5f;
|
|
final float scale = 1250f;
|
|
// Since the graph is being constructed once per execution here, we can use a constant for the
|
// input image. If the graph were to be re-used for multiple input images, a placeholder would
|
// have been more appropriate.
|
//final Output<String> input = b.constant("input",imageByte);
|
final Output<String> input = b.constant("DecodeJpeg/contents",imageByte);
|
final Output<Float> output =
|
b.div(
|
b.sub(
|
b.resizeBilinear(
|
b.expandDims(
|
b.cast(b.decodeJpeg(input, 3), Float.class),
|
b.constant("make_batch", 0)),
|
b.constant("size", new int[] {H, W})),
|
b.constant("mean", mean)),
|
b.constant("scale", scale));
|
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
|
//return model.session().runner().fetch(output).run().get(0).expect(Float.class);
|
Session s = new Session(g);
|
return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
|
}
|
|
private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
|
try (Graph g = new Graph()) {
|
g.importGraphDef(graphDef);
|
try (Session s = new Session(g);
|
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
|
Tensor<Float> result =
|
s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
|
//s.runner().feed("Placeholder:0", image).fetch("final_result:0").run().get(0).expect(Float.class)) {
|
final long[] rshape = result.shape();
|
if (result.numDimensions() != 2 || rshape[0] != 1) {
|
throw new RuntimeException(
|
String.format(
|
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
|
Arrays.toString(rshape)));
|
}
|
int nlabels = (int) rshape[1];
|
return result.copyTo(new float[1][nlabels])[0];
|
}
|
}
|
|
}
|
|
private static float[] executeInceptionGraph2(SavedModelBundle model, Tensor<Float> image) {
|
|
Session s = model.session();
|
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
|
Tensor<Float> result =
|
//s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
|
s.runner().feed("Placeholder:0", image).fetch("final_result:0").run().get(0).expect(Float.class); {
|
final long[] rshape = result.shape();
|
if (result.numDimensions() != 2 || rshape[0] != 1) {
|
throw new RuntimeException(
|
String.format(
|
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
|
Arrays.toString(rshape)));
|
}
|
int nlabels = (int) rshape[1];
|
return result.copyTo(new float[1][nlabels])[0];
|
|
}
|
|
}
|
|
private static int maxIndex(float[] probabilities) {
|
int best = 0;
|
for (int i = 1; i < probabilities.length; ++i) {
|
if (probabilities[i] > probabilities[best]) {
|
best = i;
|
}
|
}
|
return best;
|
}
|
|
private static byte[] readAllBytesOrExit(Path path) {
|
try {
|
return Files.readAllBytes(path);
|
} catch (IOException e) {
|
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
|
System.exit(1);
|
}
|
return null;
|
}
|
|
private static List<String> readAllLinesOrExit(Path path) {
|
try {
|
return Files.readAllLines(path, Charset.forName("UTF-8"));
|
} catch (IOException e) {
|
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
|
System.exit(0);
|
}
|
return null;
|
}
|
|
// In the fullness of time, equivalents of the methods of this class should be auto-generated from
|
// the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
|
// like Python, C++ and Go.
|
static class GraphBuilder {
|
GraphBuilder(Graph g) {
|
this.g = g;
|
}
|
|
Output<Float> div(Output<Float> x, Output<Float> y) {
|
return binaryOp("Div", x, y);
|
}
|
|
Output<Float> holder(Output<Float> x, Output<Float> y) {
|
return binaryOp("Placeholder:0", x, y);
|
}
|
|
<T> Output<T> sub(Output<T> x, Output<T> y) {
|
return binaryOp("Sub", x, y);
|
}
|
|
<T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
|
return binaryOp3("ResizeBilinear", images, size);
|
}
|
|
<T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
|
return binaryOp3("ExpandDims", input, dim);
|
}
|
|
<T, U> Output<U> cast(Output<T> value, Class<U> type) {
|
DataType dtype = DataType.fromClass(type);
|
return g.opBuilder("Cast", "Cast")
|
.addInput(value)
|
.setAttr("DstT", dtype)
|
.build()
|
.<U>output(0);
|
}
|
|
Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
|
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
|
.addInput(contents)
|
.setAttr("channels", channels)
|
.build()
|
.<UInt8>output(0);
|
}
|
|
<T> Output<T> constant(String name, Object value, Class<T> type) {
|
try (Tensor<T> t = Tensor.<T>create(value, type)) {
|
return g.opBuilder("Const", name)
|
.setAttr("dtype", DataType.fromClass(type))
|
.setAttr("value", t)
|
.build()
|
.<T>output(0);
|
}
|
}
|
Output<String> constant(String name, byte[] value) {
|
return this.constant(name, value, String.class);
|
}
|
|
Output<Integer> constant(String name, int value) {
|
return this.constant(name, value, Integer.class);
|
}
|
|
Output<Integer> constant(String name, int[] value) {
|
return this.constant(name, value, Integer.class);
|
}
|
|
Output<Float> constant(String name, float value) {
|
return this.constant(name, value, Float.class);
|
}
|
|
private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
|
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
|
}
|
|
private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
|
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
|
}
|
private Graph g;
|
}
|
}
|