|  | 
| 42 | 42 |  "id": "hRTa3Ee15WsJ" | 
| 43 | 43 |  }, | 
| 44 | 44 |  "source": [ | 
| 45 |  | - "# Retrain a classification model for Edge TPU (with TF 2.0)" | 
|  | 45 | + "# Retrain a classification model for Edge TPU using post-training quantization (with TF2)" | 
| 46 | 46 |  ] | 
| 47 | 47 |  }, | 
| 48 | 48 |  { | 
|  | 
| 52 | 52 |  "id": "TaX0smDP7xQY" | 
| 53 | 53 |  }, | 
| 54 | 54 |  "source": [ | 
| 55 |  | - "In this tutorial, we'll use TensorFlow 2.x to create an image classification model, train it with a flowers dataset, and convert it to TensorFlow Lite using post-training quantization. Finally, we compile it for compatibility with the Edge TPU (available in [Coral devices](https://coral.ai/products/)).\n", | 
|  | 55 | + "In this tutorial, we'll use TensorFlow 2.3 to create an image classification model, train it with a flowers dataset, and convert it to TensorFlow Lite using post-training quantization. Finally, we compile it for compatibility with the Edge TPU (available in [Coral devices](https://coral.ai/products/)).\n", | 
| 56 | 56 |  "\n", | 
| 57 | 57 |  "The model is based on a pre-trained version of MobileNet V2. We'll start by retraining only the classification layers, reusing MobileNet's pre-trained feature extractor layers. Then we'll fine-tune the model by updating weights in some of the feature extractor layers. This type of transfer learning is much faster than training the entire model from scratch.\n", | 
| 58 | 58 |  "\n", | 
| 59 | 59 |  "Once it's trained, we'll use post-training quantization to convert all parameters to int8 format, which reduces the model size and increases inferencing speed. This format is also required for compatibility on the Edge TPU.\n", | 
| 60 | 60 |  "\n", | 
| 61 | 61 |  "For more information about how to create a model compatible with the Edge TPU, see the [documentation at coral.ai](https://coral.ai/docs/edgetpu/models-intro/).\n", | 
| 62 | 62 |  "\n", | 
| 63 |  | - "**Note:** This tutorial requires TensorFlow 2.0+. If you're using TF 1.x, see [the 1.x version of this tutorial](https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf1.ipynb)." | 
|  | 63 | + "**Note:** This tutorial requires TensorFlow 2.3+ and depends on an early release version of the `TFliteConverter` for full quantization, which currently does not work for all types of models. In particular, this tutorial expects a Keras-built model and this conversion strategy currently doesn't work with models imported from a frozen graph. (If you're using TF 1.x, see [the 1.x version of this tutorial](https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf1.ipynb).)" | 
| 64 | 64 |  ] | 
| 65 | 65 |  }, | 
| 66 | 66 |  { | 
|  | 
| 95 | 95 |  "## Import the required libraries" | 
| 96 | 96 |  ] | 
| 97 | 97 |  }, | 
|  | 98 | + { | 
|  | 99 | + "cell_type": "markdown", | 
|  | 100 | + "metadata": { | 
|  | 101 | + "colab_type": "text", | 
|  | 102 | + "id": "02MxhCyFmpzn" | 
|  | 103 | + }, | 
|  | 104 | + "source": [ | 
|  | 105 | + "**Note:** Until TensorFlow 2.3 is released as stable, we need to install the nightly build in order to use the latest `TFLiteConverter` that supports quantization for input and output tensors:" | 
|  | 106 | + ] | 
|  | 107 | + }, | 
|  | 108 | + { | 
|  | 109 | + "cell_type": "code", | 
|  | 110 | + "execution_count": 0, | 
|  | 111 | + "metadata": { | 
|  | 112 | + "colab": {}, | 
|  | 113 | + "colab_type": "code", | 
|  | 114 | + "id": "L-YbcBDDmaxO" | 
|  | 115 | + }, | 
|  | 116 | + "outputs": [], | 
|  | 117 | + "source": [ | 
|  | 118 | + "! pip uninstall -y tensorflow\n", | 
|  | 119 | + "! pip install tf-nightly" | 
|  | 120 | + ] | 
|  | 121 | + }, | 
| 98 | 122 |  { | 
| 99 | 123 |  "cell_type": "code", | 
| 100 | 124 |  "execution_count": 0, | 
|  | 
| 106 | 130 |  "outputs": [], | 
| 107 | 131 |  "source": [ | 
| 108 | 132 |  "import tensorflow as tf\n", | 
| 109 |  | - "assert tf.__version__.startswith('2')\n", | 
|  | 133 | + "assert tf.__version__.startswith('2.3')\n", | 
| 110 | 134 |  "\n", | 
| 111 | 135 |  "import os\n", | 
| 112 | 136 |  "import numpy as np\n", | 
|  | 
| 813 | 837 |  "converter.target_spec.supported_types = [tf.int8]\n", | 
| 814 | 838 |  "# This ensures that if any ops can't be quantized, the converter throws an error\n", | 
| 815 | 839 |  "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", | 
|  | 840 | + "# These set the input and output tensors to uint8 (added in r2.3)\n", | 
|  | 841 | + "converter.inference_input_type = tf.uint8\n", | 
|  | 842 | + "converter.inference_output_type = tf.uint8\n", | 
| 816 | 843 |  "# And this sets the representative dataset so we can quantize the activations\n", | 
| 817 | 844 |  "converter.representative_dataset = representative_data_gen\n", | 
| 818 | 845 |  "tflite_model = converter.convert()\n", | 
|  | 
| 821 | 848 |  " f.write(tflite_model)" | 
| 822 | 849 |  ] | 
| 823 | 850 |  }, | 
| 824 |  | - { | 
| 825 |  | - "cell_type": "markdown", | 
| 826 |  | - "metadata": { | 
| 827 |  | - "colab_type": "text", | 
| 828 |  | - "id": "O7XqlzVfxCtD" | 
| 829 |  | - }, | 
| 830 |  | - "source": [ | 
| 831 |  | - "**Note:** An alternative technique to quantize the model is to use [quantization-aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training). This typically results in better accuracy because the training accounts for the decreased parameter precision. But it requires modification to the model graph before initial training, which isn't always possible if you don't have a robust training dataset." | 
| 832 |  | - ] | 
| 833 |  | - }, | 
| 834 | 851 |  { | 
| 835 | 852 |  "cell_type": "markdown", | 
| 836 | 853 |  "metadata": { | 
|  | 
| 901 | 918 |  " input_details = interpreter.get_input_details()[0]\n", | 
| 902 | 919 |  " tensor_index = input_details['index']\n", | 
| 903 | 920 |  " input_tensor = interpreter.tensor(tensor_index)()[0]\n", | 
| 904 |  | - " input_tensor[:, :] = input\n", | 
| 905 |  | - " # NOTE: This model uses float inputs, but if inputs were uint8,\n", | 
| 906 |  | - " # we would quantize the input like this:\n", | 
| 907 |  | - " # scale, zero_point = input_details['quantization']\n", | 
| 908 |  | - " # input_tensor[:, :] = np.uint8(input / scale + zero_point)\n", | 
|  | 921 | + " # Inputs for the TFLite model must be uint8, so we quantize our input data.\n", | 
|  | 922 | + " # NOTE: This step is necessary only because we're receiving input data from\n", | 
|  | 923 | + " # ImageDataGenerator, which rescaled all image data to float [0,1]. When using\n", | 
|  | 924 | + " # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:\n", | 
|  | 925 | + " # input_tensor[:, :] = input\n", | 
|  | 926 | + " scale, zero_point = input_details['quantization']\n", | 
|  | 927 | + " input_tensor[:, :] = np.uint8(input / scale + zero_point)\n", | 
| 909 | 928 |  "\n", | 
| 910 | 929 |  "def classify_image(interpreter, input):\n", | 
| 911 | 930 |  " set_input_tensor(interpreter, input)\n", | 
| 912 | 931 |  " interpreter.invoke()\n", | 
| 913 | 932 |  " output_details = interpreter.get_output_details()[0]\n", | 
| 914 | 933 |  " output = interpreter.get_tensor(output_details['index'])\n", | 
| 915 |  | - " # NOTE: This model uses float outputs, but if outputs were uint8,\n", | 
| 916 |  | - " # we would dequantize the results like this:\n", | 
| 917 |  | - " # scale, zero_point = output_details['quantization']\n", | 
| 918 |  | - " # output = scale * (output - zero_point)\n", | 
|  | 934 | + " # Outputs from the TFLite model are uint8, so we dequantize the results:\n", | 
|  | 935 | + " scale, zero_point = output_details['quantization']\n", | 
|  | 936 | + " output = scale * (output - zero_point)\n", | 
| 919 | 937 |  " top_1 = np.argmax(output)\n", | 
| 920 | 938 |  " return top_1\n", | 
| 921 | 939 |  "\n", | 
|  | 
| 1104 | 1122 |  "\n", | 
| 1105 | 1123 |  "Check out more examples for running inference at [coral.ai/examples](https://coral.ai/examples/#code-examples/)." | 
| 1106 | 1124 |  ] | 
| 1107 |  | - }, | 
| 1108 |  | - { | 
| 1109 |  | - "cell_type": "markdown", | 
| 1110 |  | - "metadata": { | 
| 1111 |  | - "colab_type": "text", | 
| 1112 |  | - "id": "AsMDBkU43qen" | 
| 1113 |  | - }, | 
| 1114 |  | - "source": [ | 
| 1115 |  | - "### Notice about float inputs/outputs\n", | 
| 1116 |  | - "\n", | 
| 1117 |  | - "Currently, the [`TFLiteConverter`](https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter) v2 always leaves the input and output tensors in float format, but because all internal parameters are int8, the converter adds a quantize op at the beginning of the graph and a dequantize op at the end.\n", | 
| 1118 |  | - "\n", | 
| 1119 |  | - "These quant/dequant ops cannot be compiled for the Edge TPU, so the Edge TPU Compiler leaves them to run on the CPU—the rest of the model runs on the Edge TPU.\n", | 
| 1120 |  | - "\n", | 
| 1121 |  | - "You can see from the compiler log file that just two ops did not map to the Edge TPU (the `DEQUANTIZE` and the `QUANTIZE` ops):\n" | 
| 1122 |  | - ] | 
| 1123 |  | - }, | 
| 1124 |  | - { | 
| 1125 |  | - "cell_type": "code", | 
| 1126 |  | - "execution_count": 0, | 
| 1127 |  | - "metadata": { | 
| 1128 |  | - "colab": {}, | 
| 1129 |  | - "colab_type": "code", | 
| 1130 |  | - "id": "SCKb9EeS354f" | 
| 1131 |  | - }, | 
| 1132 |  | - "outputs": [], | 
| 1133 |  | - "source": [ | 
| 1134 |  | - "! cat mobilenet_v2_1.0_224_quant_edgetpu.log" | 
| 1135 |  | - ] | 
| 1136 |  | - }, | 
| 1137 |  | - { | 
| 1138 |  | - "cell_type": "markdown", | 
| 1139 |  | - "metadata": { | 
| 1140 |  | - "colab_type": "text", | 
| 1141 |  | - "id": "d0Bspd0q74aI" | 
| 1142 |  | - }, | 
| 1143 |  | - "source": [ | 
| 1144 |  | - "These quant/dequant ops add negligible latency, but ideally these should be removed so you can feed the model quantized (not float) inputs directly. If you use the `TFLiteConverter` in TF 1.15, you can remove them using the `inference_input_type` and `inference_output_type` options during conversion. The TensorFlow team is working to bring these options to TF 2.x and you can follow the progress on [this GitHub issue](https://github.com/tensorflow/tensorflow/issues/38285).\n", | 
| 1145 |  | - "\n", | 
| 1146 |  | - "You can also read more about [Coral's model compatibility with float inputs](https://coral.ai/docs/edgetpu/models-intro/#float-input-and-output-tensors)." | 
| 1147 |  | - ] | 
| 1148 | 1125 |  } | 
| 1149 | 1126 |  ], | 
| 1150 | 1127 |  "metadata": { | 
|  | 
0 commit comments