Recognizing Objects with TensorFlow Lite on Android

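I recently needed to look into how face recognition can be implemented. Plenty of SDKs are available online, but most of them are paid. A machine learning library takes effort to maintain, so few SDKs are free to use.

After searching online, I found a lightweight open-source library: TensorFlow Lite

This post uses image recognition as the example. Face recognition will be covered in a later post.

Reference article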

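Step 1: Download the TensorFlow model

Use the MobileNet_v1_1.0_224 model (download link).

After extracting the archive, you will find the file mobilenet_v1_1.0_224.tflite

Step 2: Download the model labels

The model above has 1001 classes, but the archive does not include the label file.
The label file can be downloaded here (download link).

Step 3: Add TensorFlow Lite to the project

Add the libraries to the dependencies block in build.gradle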

build.gradle
dependencies {
    ...
    implementation 'com.github.bumptech.glide:glide:4.3.1'
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}

Then add the following to the android block. Its main purpose is to keep the build from compressing the TensorFlow Lite model file.

build.gradle
android {
    ...
    aaptOptions {
        noCompress "tflite"
    }
}

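Step 4: Add the model

Create an assets directory under main (src/main/assets) and put the .tflite file and the label .txt file there.

The screen has two buttons: load model and analyze image

The code behind the load-model button is as follows

loadModelFile: reads the model file into a ByteBuffer, which is then passed to the Interpreter for initialization.
load_model: creates the Interpreter that is later used to analyze images, and sets the number of threads it uses (tflite.setNumThreads)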

MainActivity.java
private final String MODEL_NAME = "mobilenet_v1_1.0_224";

private Interpreter tflite = null;

@Override
public void onClick(View v) {
    switch (v.getId()) {
        case R.id.load_model:
            load_model(MODEL_NAME);
            break;
        ...
    }
}

// Creates the Interpreter from the memory-mapped model file and reports the outcome.
private void load_model(String model) {
    try {
        tflite = new Interpreter(loadModelFile(model));
        Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
        Log.d(TAG, model + " model load success");
        // Later TensorFlow Lite releases deprecate this call in favor of Interpreter.Options.setNumThreads.
        tflite.setNumThreads(4);
        load_result = true;
    } catch (IOException e) {
        Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
        Log.e(TAG, model + " model load fail", e);
        load_result = false;
    }
}

// Memory-maps <model>.tflite from assets. openFd() only works on uncompressed assets,
// which is why build.gradle sets noCompress "tflite".
private ByteBuffer loadModelFile(String model) throws IOException {
    AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    // The asset lives inside the APK, so map only its own byte range.
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
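
The offset-and-length mapping in loadModelFile can be tried outside Android. The following is a minimal sketch in plain Java, assuming an ordinary file stands in for the APK; the class name RegionMapperSketch, the demo method, and the "HEADmodel-bytesTAIL" payload are made up for illustration and are not part of the original app.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/** Hypothetical helper for illustration only: maps a byte range of a file read-only. */
final class RegionMapperSketch {
    private RegionMapperSketch() {}

    /** Maps length bytes starting at offset, like the (startOffset, declaredLength) pair from openFd(). */
    static MappedByteBuffer mapRegion(Path file, long offset, long length) {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            // A mapping stays valid after the channel that created it is closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Writes a fake container file and maps only the 11 bytes in the middle of it. */
    static String demo() {
        try {
            Path tmp = Files.createTempFile("region", ".bin");
            tmp.toFile().deleteOnExit();
            Files.write(tmp, "HEADmodel-bytesTAIL".getBytes(StandardCharsets.US_ASCII));
            MappedByteBuffer region = mapRegion(tmp, 4, 11);
            byte[] out = new byte[region.remaining()];
            region.get(out);
            return new String(out, StandardCharsets.US_ASCII);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Calling RegionMapperSketch.demo() returns only the middle part of the file. This mirrors why the asset must be stored uncompressed: a byte range can only be mapped directly if the bytes on disk are the model's own bytes.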

Step 5: Add the labels

Read the label file when the app starts

MainActivity.java
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    ...
    readCacheLabelFromLocalFile();
}

// Reads the label file from assets, one label per line; the line index is the class index.
private void readCacheLabelFromLocalFile() {
    AssetManager assetManager = getApplicationContext().getAssets();
    // try-with-resources closes the reader even when reading fails
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open(LABEL_NAME + ".txt")))) {
        String readLine;
        while ((readLine = reader.readLine()) != null) {
            resultLabel.add(readLine);
        }
    } catch (Exception e) {
        Log.e("labelCache", "error " + e);
    }
}
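
The label-reading loop does not depend on Android and can be sketched in plain Java. The version below is an illustration only: the class name LabelReaderSketch and the method readLabels are hypothetical. Blank lines are kept on purpose, because the position of a line is what ties a label to a class index in the model output.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for illustration only: reads one label per line. */
final class LabelReaderSketch {
    private LabelReaderSketch() {}

    /** Returns the lines of source in order; index i holds the label for class i. */
    static List<String> readLabels(Reader source) {
        List<String> labels = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(source)) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Do not skip empty lines: that would shift every later index.
                labels.add(line);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return labels;
    }
}
```

On Android the Reader would wrap the stream returned by AssetManager.open(), as in the code above; in a quick test a java.io.StringReader works just as well.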

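Step 6: Analyze the image

predict_image: loads the image, downsamples it, converts it to a ByteBuffer, and then calls Interpreter.run() to run inference.
get_max_result: returns the index of the label with the highest probability after inference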
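
The two preprocessing steps described for predict_image (choosing a power-of-two sample size, then packing pixels as floats) can be sketched without any Android classes. This is an illustration under assumptions: the class ImagePrepSketch and its method names are hypothetical, pixels are assumed to be ARGB ints as returned by Bitmap.getPixels(), and the (value - 128) / 128 scaling is the one used in this post's code.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper for illustration only; not part of the original app. */
final class ImagePrepSketch {
    private ImagePrepSketch() {}

    /** Doubles the sample size while both sides are still at least maxSize after division. */
    static int sampleSize(int width, int height, int maxSize) {
        if (maxSize <= 0) {
            throw new IllegalArgumentException("maxSize must be positive");
        }
        int sample = 1;
        while (width / sample >= maxSize && height / sample >= maxSize) {
            sample *= 2;
        }
        return sample;
    }

    /** Maps a 0..255 channel value to the range [-1, 1). */
    static float normalize(int channel) {
        return (channel - 128f) / 128f;
    }

    /** Packs ARGB pixels as R, G, B floats (4 bytes each) into a direct buffer. */
    static ByteBuffer pack(int[] argbPixels) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(argbPixels.length * 3 * 4);
        buffer.order(ByteOrder.nativeOrder());
        for (int pixel : argbPixels) {
            buffer.putFloat(normalize((pixel >> 16) & 0xFF)); // red
            buffer.putFloat(normalize((pixel >> 8) & 0xFF));  // green
            buffer.putFloat(normalize(pixel & 0xFF));         // blue
        }
        // Rewind so the sketch's caller can read the floats back from the start.
        buffer.rewind();
        return buffer;
    }
}
```

For a 4000 x 3000 photo and a limit of 500, sampleSize returns 8, which gives a decoded image of roughly 500 x 375. For a 224 x 224 input, pack produces 224 * 224 * 3 * 4 bytes, the same size the post's code allocates.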

MainActivity.java
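private List<String> resultLabel = new ArrayList<>();
private ImageView mImage;
// Input dimensions. Only the product of all four and the last two (height, width) are used below;
// the model itself takes a [1, 224, 224, 3] float input.
private int[] ddims = {1, 3, 224, 224};
private TextView mResult;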

@Override
public void onClick(View v) {
    switch (v.getId()) {
        ...
        case R.id.use_photo:
            // Show the photo without caching, then classify the same file
            RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
            File file = new File("/storage/sdcard1/mouse.jpeg");
            Uri photoUri = Uri.fromFile(file);
            Glide.with(MainActivity.this).load(photoUri).apply(options).into(mImage);
            predict_image(file.getAbsolutePath());
            break;
        default:
            break;
    }
}

private void predict_image(String image_path) {
    if (tflite == null) {
        // The model has not been loaded yet (load_model was not called or failed)
        Log.d(TAG, "model not loaded");
        return;
    }
    try {
        // Decode the picture and convert it to a float buffer.
        // Done inside the try block so that a missing or unreadable file does not crash the app.
        Bitmap bmp = getScaleBitmap(image_path);
        ByteBuffer inputData = getScaledMatrix(bmp, ddims);
        float[][] labelProbArray = new float[1][1001];
        long start = System.currentTimeMillis();
        // get predict result
        tflite.run(inputData, labelProbArray);
        long end = System.currentTimeMillis();
        long time = end - start;

        float[] results = new float[labelProbArray[0].length];
        System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
        // show predict result and time
        int r = get_max_result(results);

        String show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
        mResult.setText(show_text);
    } catch (Exception e) {
        e.printStackTrace();
    }
}

private Bitmap getScaleBitmap(String filePath) {
    BitmapFactory.Options opt = new BitmapFactory.Options();
    // First pass: read only the image size, without decoding pixels
    opt.inJustDecodeBounds = true;
    BitmapFactory.decodeFile(filePath, opt);

    int bmpWidth = opt.outWidth;
    int bmpHeight = opt.outHeight;
    int maxSize = 500;
    // compress picture with inSampleSize: double it while both sides stay at least maxSize
    opt.inSampleSize = 1;
    while (bmpWidth / opt.inSampleSize >= maxSize && bmpHeight / opt.inSampleSize >= maxSize) {
        opt.inSampleSize *= 2;
    }
    // Second pass: decode the downsampled pixels
    opt.inJustDecodeBounds = false;

    return BitmapFactory.decodeFile(filePath, opt);
}

private ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
    // 4 bytes per float value
    ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
    imgData.order(ByteOrder.nativeOrder());
    // get image pixel
    int[] pixels = new int[ddims[2] * ddims[3]];
    Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
    bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);
    int pixel = 0;
    for (int i = 0; i < ddims[2]; ++i) {
        for (int j = 0; j < ddims[3]; ++j) {
            final int val = pixels[pixel++];
            // R, G, B channels scaled from 0..255 to [-1, 1)
            imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
            imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
            imgData.putFloat((((val & 0xFF) - 128f) / 128f));
        }
    }

    // Release the scaled bitmap once its pixels have been copied
    if (!bm.isRecycled()) {
        bm.recycle();
    }
    return imgData;
}

// Returns the index of the largest value in result
private int get_max_result(float[] result) {
    float probability = result[0];
    int r = 0;
    for (int i = 1; i < result.length; i++) {
        if (probability < result[i]) {
            probability = result[i];
            r = i;
        }
    }
    return r;
}
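
get_max_result returns only the single best class. When several candidates are of interest, the same probability array can be ranked instead. The following is a small sketch in plain Java, not something the original app does; the class name TopKSketch and the method topIndices are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for illustration only: ranks class indices by probability. */
final class TopKSketch {
    private TopKSketch() {}

    /** Returns up to k indices of probs, highest probability first. */
    static List<Integer> topIndices(float[] probs, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            indices.add(i);
        }
        // Descending by probability; the sort is stable, so ties keep the lower index first.
        indices.sort((a, b) -> Float.compare(probs[b], probs[a]));
        int count = Math.max(0, Math.min(k, indices.size()));
        return new ArrayList<>(indices.subList(0, count));
    }
}
```

Each returned index can be passed to resultLabel.get(...) in the same way as the single index returned by get_max_result.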

Result

Author: Nick Lin

Published: 2019-04-10

Updated: 2023-01-18