RISC-V MCU Chinese Community

TFLite: Model Quantization

Posted on 2023-04-17 10:04:52

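1 Introduction

1.1 What is model quantization?

In short, model quantization converts a floating-point model into a fixed-point model.

Models trained with tools such as TensorFlow, PyTorch, or Caffe usually store their weights as float32. Such models are accurate but large, and in memory-constrained scenarios the model size has to be reduced as much as possible. Quantization addresses this by storing values as fixed-point numbers: going from float32 to int8 cuts memory usage by roughly 4x and also speeds up computation considerably, at the cost of a small accuracy loss. Model quantization can therefore be thought of as a form of lossy compression that trades a little accuracy for a much smaller and faster model.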
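As a rough illustration of the float-to-fixed-point mapping described above, the following sketch quantizes a handful of float values to int8 with an affine (scale and zero point) scheme and then restores them. It is a minimal pure-Python sketch; the function names and sample weights are made up for illustration and are not part of TensorFlow Lite.

```python
# Minimal sketch of affine int8 quantization; all names are illustrative.

def quant_params(values, qmin=-128, qmax=127):
    """Derive scale and zero point so that [min, max] maps onto [qmin, qmax]."""
    lo = min(min(values), 0.0)  # the range must contain 0 so 0.0 is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 if every value is 0
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to integers in [qmin, qmax]."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]

def dequantize(q_values, scale, zero_point):
    """Map integers back to (approximate) floats."""
    return [(q - zero_point) * scale for q in q_values]

weights = [-0.82, -0.10, 0.0, 0.37, 1.25]
scale, zero_point = quant_params(weights)
q_weights = quantize(weights, scale, zero_point)
restored = dequantize(q_weights, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print("int8 values:", q_weights)
print("max round-trip error:", max_error)
```

Each value now fits in one byte instead of four, and the round-trip error stays within about half a quantization step, which is the "lossy compression" trade-off mentioned above.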

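1.2 Why quantize a model?

The table below compares a model before and after quantization:

                      Parameters         Compute             Memory usage   Accuracy
Before quantization   Many parameters    Heavy computation   High           High
After quantization    Compressed         Faster              Reduced        Some loss

Embedded AI scenarios place tight constraints on memory and compute speed but can tolerate some accuracy loss, so model quantization can be used there to reduce the model's size and computational cost.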

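2 Quantization methods

TensorFlow offers two ways to quantize a model:

  1. Quantization aware training, also called in-training quantization, based on tf.keras
  2. Post-training quantization: once the model has been trained, it is quantized with the TensorFlow Lite converter

They differ as follows:

  • Post-training quantization iterates quickly and is easy to use, but loses more model accuracy;

  • Quantization aware training is harder to use and requires retraining the model, but preserves model accuracy better.

Users can weigh ease of use, iteration time, and model accuracy against each other and pick the method that suits them.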
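To make the difference between the two approaches more concrete, here is a conceptual sketch of the "fake quantization" idea that quantization aware training is commonly built on: weights are snapped to the int8 grid during the forward pass, so the training loss already sees the rounding error. Post-training quantization applies the same kind of rounding only once, after training has finished. This is a pure-Python illustration with invented names and values; it does not use the tf.keras or TensorFlow Model Optimization APIs.

```python
# Conceptual sketch of fake quantization; names and values are illustrative only.

def fake_quant(w, scale):
    """Quantize then immediately dequantize: snap w onto the symmetric int8 grid."""
    q = max(-127, min(127, round(w / scale)))
    return q * scale

weights = [0.731, -0.262, 0.018]
inputs = [1.0, 2.0, 3.0]

# Symmetric per-tensor scale derived from the largest weight magnitude
scale = max(abs(w) for w in weights) / 127

# Forward pass with the original float weights
float_out = sum(w * x for w, x in zip(weights, inputs))

# Forward pass as a quantization-aware layer would compute it
fake_quant_out = sum(fake_quant(w, scale) * x for w, x in zip(weights, inputs))

print("float output:          ", float_out)
print("fake-quantized output: ", fake_quant_out)
print("difference:            ", abs(float_out - fake_quant_out))
```

During quantization aware training the optimizer can adapt the weights to keep this difference small, which is one way to understand why that approach tends to preserve accuracy better.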

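3 Post-training quantization

Post-training quantization is the simpler of the two, so this article uses MNIST to walk through it (quantization aware training may be covered in a later article).

3.1 Training the model

Use the following script to train the model and save it as mnist_train.h5: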

    # mnist_train.py
    import tensorflow as tf
    from tensorflow import keras

    print("TensorFlow version {}".format(tf.__version__))

    # Load MNIST; pixel values are left in their raw 0..255 range
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
    class_num = 10

    # Build the model
    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=(train_images.shape[1], train_images.shape[2])),  # input_shape=(28, 28)
        keras.layers.Dense(512, activation=tf.nn.relu),
        keras.layers.Dense(64, activation=tf.nn.relu),
        keras.layers.Dense(class_num, activation=tf.nn.softmax)
    ])
    model.compile(optimizer='Adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    print("model structure:")
    model.summary()

    # Train
    model.fit(train_images, train_labels, epochs=5, batch_size=64)

    # Evaluate accuracy on the test set
    loss, acc = model.evaluate(test_images, test_labels)
    print("Restored model, accuracy: {:5.2f}% loss: {:5.2f}".format(100*acc, loss))

    # Save the trained model in HDF5 format
    model.save('mnist_train.h5')
    del model  # discard the in-memory model

The log is as follows:

    TensorFlow version 2.12.0
    model structure:
    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #
    =================================================================
     flatten (Flatten)           (None, 784)               0
     dense (Dense)               (None, 512)               401920
     dense_1 (Dense)             (None, 64)                32832
     dense_2 (Dense)             (None, 10)                650
    =================================================================
    Total params: 435,402
    Trainable params: 435,402
    Non-trainable params: 0
    _________________________________________________________________
    Epoch 1/5
    938/938 [==============================] - 3s 3ms/step - loss: 1.4964 - accuracy: 0.7708
    Epoch 2/5
    938/938 [==============================] - 3s 3ms/step - loss: 0.5160 - accuracy: 0.8773
    Epoch 3/5
    938/938 [==============================] - 3s 3ms/step - loss: 0.2994 - accuracy: 0.9254
    Epoch 4/5
    938/938 [==============================] - 3s 3ms/step - loss: 0.1852 - accuracy: 0.9519
    Epoch 5/5
    938/938 [==============================] - 3s 3ms/step - loss: 0.1400 - accuracy: 0.9615
    313/313 [==============================] - 1s 2ms/step - loss: 0.1657 - accuracy: 0.9590
    Restored model, accuracy: 95.90% loss: 0.17

This produces the mnist_train.h5 model.

3.2 Post-training quantization

1. Hybrid quantization (also known as dynamic range quantization):

    # mnist_quant_hybrid.py
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    # Load the h5 model and evaluate its accuracy
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
    model = keras.models.load_model('mnist_train.h5')  # load the HDF5 file 'mnist_train.h5'
    loss, acc = model.evaluate(test_images, test_labels)
    print("Restored model, accuracy: {:5.2f}% loss: {:5.2f}".format(100*acc, loss))

    # Convert with the default optimizations; without a representative
    # dataset this yields hybrid (dynamic range) quantization
    tflite_mnist_model = 'mnist_quant_hybrid.tflite'
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(tflite_mnist_model, "wb") as f:
        flatbuffer_size = f.write(tflite_model)
    print('hybrid: The size of the converted flatbuffer is: %d bytes' % flatbuffer_size)

    # Evaluate the accuracy of the quantized TFLite model in Python on the PC
    def evaluate(interpreter_path):
        # Load the model and allocate tensors
        interpreter = tf.lite.Interpreter(model_path=interpreter_path)
        interpreter.allocate_tensors()
        # Get the input and output tensor details
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        index = input_details[0]['index']
        shape = input_details[0]['shape']
        acc_count = 0
        image_count = test_images.shape[0]
        for i in range(image_count):
            # The hybrid model still takes float32 input
            interpreter.set_tensor(index, test_images[i].reshape(shape).astype("float32"))
            interpreter.invoke()
            output_data = interpreter.get_tensor(output_details[0]['index'])
            label = np.argmax(output_data)
            if label == test_labels[i]:
                acc_count += 1
        print("test_images accuracy is {:.2%}".format(acc_count / image_count))

    evaluate(tflite_mnist_model)
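The script above keeps feeding float32 input to the converted model. That fits the usual description of hybrid (dynamic range) quantization: weights are stored as 8-bit integers while inputs and outputs remain floating point. The sketch below mimics this for a single dense layer in pure Python. The helper names, weights, and inputs are invented for illustration, and the arithmetic is a simplification rather than the actual TFLite kernel.

```python
# Conceptual sketch of hybrid quantization for one dense layer.
# All names are illustrative; this is not the TFLite kernel implementation.

def quantize_weights(weights):
    """Symmetric per-tensor int8 quantization of a weight matrix (one row per output unit)."""
    max_abs = max(abs(w) for row in weights for w in row) or 1.0
    scale = max_abs / 127
    q_weights = [[round(w / scale) for w in row] for row in weights]
    return q_weights, scale

def dense_float(x, weights, bias):
    """Reference dense layer using the original float weights."""
    return [sum(xi * w for xi, w in zip(x, row)) + b for row, b in zip(weights, bias)]

def dense_hybrid(x, q_weights, scale, bias):
    """Dense layer with int8 weights and float activations; scale is applied once per output."""
    return [sum(xi * qw for xi, qw in zip(x, row)) * scale + b
            for row, b in zip(q_weights, bias)]

weights = [[0.5, -1.0, 0.25], [0.1, 0.2, -0.3]]
bias = [0.0, 0.1]
x = [1.0, 2.0, 3.0]

q_weights, scale = quantize_weights(weights)
reference = dense_float(x, weights, bias)
hybrid_out = dense_hybrid(x, q_weights, scale, bias)
max_diff = max(abs(a - b) for a, b in zip(reference, hybrid_out))
print("int8 weights:", q_weights)
print("max output difference:", max_diff)
```

Only the weights shrink in this scheme, which is why it needs no calibration data: nothing about the activation ranges has to be known at conversion time.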

2. Full-integer quantization:

    # mnist_quant_full_integer.py
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    print("TensorFlow version {}".format(tf.__version__))

    # Load the h5 model and evaluate its accuracy
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
    model = keras.models.load_model('mnist_train.h5')  # load the HDF5 file 'mnist_train.h5'
    loss, acc = model.evaluate(test_images, test_labels)
    print("Restored model, accuracy: {:5.2f}% loss: {:5.2f}".format(100*acc, loss))

    tflite_mnist_model = 'mnist_quant_full_integer.tflite'
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Yield sample inputs so the converter can calibrate activation ranges
    def representative_data_gen():
        for image in train_images[0:100, :, :]:
            yield [image.reshape(-1, train_images.shape[1], train_images.shape[2]).astype("float32")]

    # Set the representative dataset
    converter.representative_dataset = representative_data_gen
    # Restrict the converter to int8 built-in ops
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # Set the model input and output data types
    converter.inference_input_type = tf.uint8   # or tf.int8
    converter.inference_output_type = tf.uint8  # or tf.int8
    tflite_model = converter.convert()
    with open(tflite_mnist_model, "wb") as f:
        flatbuffer_size = f.write(tflite_model)
    print('full_integer: The size of the converted flatbuffer is: %d bytes' % flatbuffer_size)

    # Evaluate the accuracy of the quantized TFLite model in Python on the PC
    def evaluate(interpreter_path):
        # Load the model and allocate tensors
        interpreter = tf.lite.Interpreter(model_path=interpreter_path)
        interpreter.allocate_tensors()
        # Get the input and output tensor details
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        index = input_details[0]['index']
        shape = input_details[0]['shape']
        acc_count = 0
        image_count = test_images.shape[0]
        for i in range(image_count):
            # The full-integer model takes uint8 input
            interpreter.set_tensor(index, test_images[i].reshape(shape).astype("uint8"))
            interpreter.invoke()
            output_data = interpreter.get_tensor(output_details[0]['index'])
            label = np.argmax(output_data)
            if label == test_labels[i]:
                acc_count += 1
        print("test_images accuracy is {:.2%}".format(acc_count / image_count))

    evaluate(tflite_mnist_model)
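The representative dataset in the script above lets the converter observe value ranges and pick a scale and zero point for each tensor. The sketch below shows that calibration idea for the input tensor only, in pure Python with invented helper names and sample pixels. Because the model in this article was trained on raw 0..255 pixels, a min/max calibration of the input should land near scale 1 and zero point 0, which is presumably why the evaluation loop can feed raw uint8 pixels directly. For a model trained on normalized inputs this would not hold, and the values reported in `input_details[0]['quantization']` (a scale and zero point pair) would have to be applied to the input first.

```python
# Sketch of min/max calibration for a uint8 input tensor; names are illustrative.

def calibrate(samples, qmin=0, qmax=255):
    """Derive scale and zero point from the value range seen in the samples."""
    lo = min(0.0, min(min(s) for s in samples))
    hi = max(0.0, max(max(s) for s in samples))
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 if every value is 0
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def quantize_input(values, scale, zero_point, qmin=0, qmax=255):
    """Convert float input values to the integer values the model expects."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]

# Stand-in for a few representative images: raw pixel values covering 0..255
representative = [[0.0, 12.0, 255.0], [0.0, 128.0, 254.0]]
scale, zero_point = calibrate(representative)

pixels = [0.0, 37.0, 255.0]
q_pixels = quantize_input(pixels, scale, zero_point)
print("scale:", scale, "zero point:", zero_point)
print("quantized pixels:", q_pixels)
```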

The results are:

    h5: 5258320 bytes accuracy: 95.80%
    hybrid: The size of the converted flatbuffer is: 441576 bytes accuracy is 95.78%
    full_integer: The size of the converted flatbuffer is: 440312 bytes accuracy is 95.69%
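These sizes can be sanity-checked against the parameter count from the training log (435,402 parameters). The arithmetic below is a back-of-the-envelope sketch: at one byte per parameter, the weights alone account for almost all of each .tflite file. The remark about the .h5 size is an assumption rather than something measured in this article: Keras saves optimizer state by default, and Adam keeps two extra values per parameter, which would put the file at roughly three times the float32 weight size.

```python
# Back-of-the-envelope size check using the numbers reported in this article.

params = 435_402            # "Total params" from the training log

float32_bytes = params * 4  # weights stored as float32
int8_bytes = params * 1     # weights stored as int8

h5_size = 5_258_320         # reported size of mnist_train.h5
hybrid_size = 441_576       # reported size of mnist_quant_hybrid.tflite
full_int_size = 440_312     # reported size of mnist_quant_full_integer.tflite

# Assumption: the .h5 file also holds Adam state (two extra float32 values per parameter)
h5_estimate = float32_bytes * 3

print("float32 weights:           ", float32_bytes, "bytes")
print("int8 weights:              ", int8_bytes, "bytes")
print("hybrid overhead:           ", hybrid_size - int8_bytes, "bytes")
print("full-integer overhead:     ", full_int_size - int8_bytes, "bytes")
print("h5 estimate with Adam state:", h5_estimate, "bytes")
print("h5 / hybrid size ratio:     {:.1f}x".format(h5_size / hybrid_size))
```

Compared with the float32 weights alone, both quantized files are close to the expected 4x reduction; the larger ratio against the .h5 file mostly reflects what else that file stores.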

Author: sureZ-ok