
Neural networks on Android: Google ML Kit and more

So, you have developed and trained a neural network for some task (say, object recognition through the camera) and want to embed it in your Android application? Then read on!

To begin with, note that Android currently only works with networks in the TensorFlow Lite format, so the original network has to go through a few conversion steps. Suppose you already have a network trained in Keras or TensorFlow. The first step is to save it in the pb format.

Let's start with the case where you work in TensorFlow directly; here everything is a little easier.

 saver = tf.train.Saver()
 tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False)
 tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True)
 saver.save(session, path_to_folder + "model.ckpt")

If you are writing in Keras, you need to obtain the session object at the beginning of the file where you train the network, keep a reference to it, and pass it to the set_session function:

 import keras.backend as K
 session = K.get_session()
 K.set_session(session)
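
To tie the Keras and TensorFlow snippets together, here is a rough sketch (not from the original article; the layer names "input" and "result" are chosen so that the tensor names match the conversion commands used later):

 import keras.backend as K
 import tensorflow as tf
 from keras.models import Sequential
 from keras.layers import Dense

 # Grab the session Keras uses before building the model
 session = K.get_session()
 K.set_session(session)

 model = Sequential()
 model.add(Dense(128, activation="relu", name="input", input_shape=(784,)))
 model.add(Dense(10, activation="softmax", name="result"))
 model.compile(optimizer="adam", loss="categorical_crossentropy")
 # ... model.fit(...) ...

 # Save the graph and checkpoint with the same calls as in the TensorFlow case
 saver = tf.train.Saver()
 tf.train.write_graph(session.graph_def, ".", "net.pbtxt", True)
 saver.save(session, "./model.ckpt")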

Great, you have saved the network; now it needs to be converted to the tflite format. To do this, we run two small utilities: the first “freezes” the network, the second converts it into the required format. The point of “freezing” is that tf does not store the layer weights in the saved pb file, but keeps them in separate checkpoints. For the subsequent conversion to tflite, all information about the neural network must be in a single file.

 freeze_graph --input_binary=false \
     --input_graph=net.pbtxt \
     --output_node_names=result/Softmax \
     --output_graph=frozen_graph.pb \
     --input_checkpoint=model.ckpt

Note that you need to know the name of the output tensor. In TensorFlow you can set it yourself; in Keras, set the name in the layer constructor:

 model.add(Dense(10,activation="softmax",name="result")) 

In this case, the tensor name will usually look like “result/Softmax”.

If that is not the case for you, you can find the name as follows:

 [print(n.name) for n in session.graph.as_graph_def().node] 

It remains to run the second script

 toco --graph_def_file=frozen_graph.pb \
     --output_file=model.tflite \
     --output_format=TFLITE \
     --inference_type=FLOAT \
     --input_arrays=input_input \
     --output_arrays=result/Softmax \
     --input_shapes=1,784
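
If the toco binary is not available in your setup, the same conversion can be done from Python. This is only a sketch, assuming a TensorFlow 1.x release where the tf.lite.TFLiteConverter API is present; the file and tensor names mirror the commands above:

 import tensorflow as tf

 # Same inputs/outputs as in the toco invocation above
 converter = tf.lite.TFLiteConverter.from_frozen_graph(
     graph_def_file="frozen_graph.pb",
     input_arrays=["input_input"],
     output_arrays=["result/Softmax"],
     input_shapes={"input_input": [1, 784]})
 tflite_model = converter.convert()

 with open("model.tflite", "wb") as f:
     f.write(tflite_model)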

Hooray! Now there is a TensorFlow Lite model in your folder, and all that remains is to integrate it into your Android application. You can do this with the fashionable new Firebase ML Kit, but there is another way, which we will cover a little later. Add the dependency to your gradle file:

 dependencies {
     // ...
     implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0'
 }

Now you need to decide whether you will keep the model somewhere on your server or ship it with the application.

Let's consider the first case: the model lives on a server. First of all, do not forget to add the permission to the manifest:

 <uses-permission android:name="android.permission.INTERNET" /> 

 // Conditions under which the model will be downloaded / updated
 FirebaseModelDownloadConditions.Builder conditionsBuilder =
         new FirebaseModelDownloadConditions.Builder().requireWifi();
 if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
     conditionsBuilder = conditionsBuilder.requireCharging();
 }
 FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

 // Create a FirebaseCloudModelSource object, passing the model name
 // (the one under which you uploaded the model to the Firebase console)
 FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")
         .enableModelUpdates(true)
         .setInitialDownloadConditions(conditions)
         .setUpdatesDownloadConditions(conditions)
         .build();
 FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

If you use a model bundled locally with the application, do not forget to add the following entry to the build.gradle file so that the model file is not compressed:

 android {
     // ...
     aaptOptions {
         noCompress "tflite"
     }
 }

After that, by analogy with the cloud model, the local model needs to be registered:

 FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model")
         .setAssetFilePath("mymodel.tflite")
         .build();
 FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

The code above assumes that your model is in the assets folder; if it is not, then instead of

  .setAssetFilePath("mymodel.tflite") 

use

  .setFilePath(filePath) 

Then we create new FirebaseModelOptions and FirebaseModelInterpreter objects.

 FirebaseModelOptions options = new FirebaseModelOptions.Builder()
         .setCloudModelName("my_cloud_model")
         .setLocalModelName("my_local_model")
         .build();
 FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);

You can register both a local and a cloud model at the same time. In that case the cloud model is used by default when it is available, and the local one otherwise.

That's almost everything; it remains to create arrays for the input / output data and run the interpreter!

 FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
         .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3})
         .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784})
         .build();

 byte[][][][] input = new byte[1][640][480][3];
 input = getYourInputData();

 FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
         .add(input) // add() as many input arrays as your model requires
         .build();

 firebaseInterpreter.run(inputs, inputOutputOptions)
         .addOnSuccessListener(
                 new OnSuccessListener<FirebaseModelOutputs>() {
                     @Override
                     public void onSuccess(FirebaseModelOutputs result) {
                         // Read the output of the model
                         float[][] output = result.<float[][]>getOutput(0);
                         float[] probabilities = output[0];
                         // ...
                     }
                 })
         .addOnFailureListener(
                 new OnFailureListener() {
                     @Override
                     public void onFailure(@NonNull Exception e) {
                         // Task failed with an exception
                         // ...
                     }
                 });
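
getYourInputData() above is a placeholder. As a rough illustration (not from the original article; the 640x480 layout and the RGB channel order are assumptions that must match how you trained the network), filling the input array from an android.graphics.Bitmap could look like this:

 import android.graphics.Bitmap;
 import android.graphics.Color;

 // Hypothetical helper: converts a Bitmap into the byte[1][640][480][3]
 // array expected by the input format declared above
 byte[][][][] bitmapToInput(Bitmap bitmap) {
     Bitmap scaled = Bitmap.createScaledBitmap(bitmap, 480, 640, true);
     byte[][][][] input = new byte[1][640][480][3];
     for (int y = 0; y < 640; y++) {
         for (int x = 0; x < 480; x++) {
             int pixel = scaled.getPixel(x, y);
             input[0][y][x][0] = (byte) Color.red(pixel);
             input[0][y][x][1] = (byte) Color.green(pixel);
             input[0][y][x][2] = (byte) Color.blue(pixel);
         }
     }
     return input;
 }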

If for some reason you do not want to use Firebase, there is another way: create the tflite Interpreter yourself and feed the data to it directly.

Add a line to build.gradle:

  implementation 'org.tensorflow:tensorflow-lite:+' 

Then create the interpreter and the arrays:

 Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
 // fill the inputs array with your data
 tflite.run(inputs, outputs);
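
loadModelFile() is not defined in the snippet above; a minimal sketch, assuming the model sits in the assets folder, is the usual memory-mapping helper (this is my illustration, not code from the original article):

 import android.content.Context;
 import android.content.res.AssetFileDescriptor;
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.nio.MappedByteBuffer;
 import java.nio.channels.FileChannel;

 // Memory-maps a .tflite file from the assets folder so the Interpreter can use it
 private MappedByteBuffer loadModelFile(Context context, String modelName) throws IOException {
     AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelName);
     FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
     FileChannel fileChannel = inputStream.getChannel();
     long startOffset = fileDescriptor.getStartOffset();
     long declaredLength = fileDescriptor.getDeclaredLength();
     return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
 }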

As you can see, in this case the code is much shorter.

That's all you need to use your neural network on Android.

Useful links:

Official ML Kit documentation
TensorFlow Lite

Source: https://habr.com/ru/post/422041/

