saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")
import keras.backend as K session = K.get_session() K.set_session(session)
freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt
model.add(Dense(10,activation="softmax",name="result"))
[print(n.name) for n in session.graph.as_graph_def().node]
toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' }
<uses-permission android:name="android.permission.INTERNET" />
// , / FirebaseModelDownloadConditions.Builder conditionsBuilder = new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { conditionsBuilder = conditionsBuilder .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); // FirebaseCloudModelSource , ( , // Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model") .enableModelUpdates(true) .setInitialDownloadConditions(conditions) .setUpdatesDownloadConditions(conditions) .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);
android { // ... aaptOptions { noCompress "tflite" } }
FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource);
.setAssetFilePath("mymodel.tflite")
.seFilePath(filePath)
FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); Task<FirebaseModelOutputs> result = firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } }); float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0];
implementation 'org.tensorflow:tensorflow-lite:+'
Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); // inputs tflite.run(inputs,outputs)
Source: https://habr.com/ru/post/422041/
All Articles