zephyr/samples/modules/tflite-micro/magic_wand/train/train.py
Jamie McCrae ec7044437e treewide: Disable automatic argparse argument shortening
Prevents the Python argparse library from automatically accepting
shortened (abbreviated) command-line arguments. This avoids issues where
a new argument is added whose name matches an abbreviation that existing
code wrongly uses for a different argument; without this change, that
code would silently change script behaviour.

Signed-off-by: Jamie McCrae <jamie.mccrae@nordicsemi.no>
2023-01-26 20:12:36 +09:00

202 lines
7.4 KiB
Python

# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Build and train neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import datetime
import os # pylint: disable=duplicate-code
from data_load import DataLoader
import numpy as np # pylint: disable=duplicate-code
import tensorflow as tf
# Each run logs TensorBoard scalars into its own directory, named after the
# local start time with one-second resolution.
_run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/scalars/" + _run_timestamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
def reshape_function(data, label):
  """Reshapes one sample's data to `[-1, 3, 1]`, passing `label` through.

  Used with `Dataset.map` so the CNN receives a trailing channel axis.
  """
  return tf.reshape(data, [-1, 3, 1]), label
def calculate_model_size(model):
  """Prints the model summary and the total size of its trainable weights.

  Args:
    model: A Keras model. Only `summary()` and `trainable_variables` are
      used; each variable must expose `shape` and `dtype.size` (bytes per
      element).

  Returns:
    The total size of the trainable variables in KB (bytes / 1024). Existing
    callers ignore this value.
  """
  # summary() prints the table itself and returns None; wrapping it in
  # print() would emit a stray "None" line.
  model.summary()
  # np.prod instead of np.product, which is deprecated and was removed in
  # NumPy 2.0.
  var_sizes = [
      np.prod([int(dim) for dim in v.shape]) * v.dtype.size
      for v in model.trainable_variables
  ]
  size_kb = sum(var_sizes) / 1024
  print("Model size:", size_kb, "KB")
  return size_kb
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  Args:
    seq_length: Number of timesteps per sample; the model input shape is
      `(seq_length, 3, 1)`. The per-layer shape comments below assume
      `seq_length == 128`.

  Returns:
    A `(model, model_path)` tuple. `model_path` is `./netmodels/CNN` and is
    created if it does not exist. If `./netmodels/CNN/weights.h5` exists its
    weights are loaded into the model; otherwise the model keeps its fresh
    initialization.
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  # Pretrained weights are optional. Nothing in this script writes
  # weights.h5, so on a fresh checkout the file is absent and an
  # unconditional load_weights() would abort before training starts.
  weights_path = os.path.join(model_path, "weights.h5")
  if os.path.exists(weights_path):
    model.load_weights(weights_path)
    print("Loaded weights from {}.".format(weights_path))
  else:
    print("No weights found at {}; training from scratch.".format(
        weights_path))
  return model, model_path
def build_lstm(seq_length):
  """Builds a bidirectional-LSTM classifier in Keras.

  Args:
    seq_length: Number of timesteps per sample; input shape is
      `(seq_length, 3)`.

  Returns:
    A `(model, model_path)` tuple; `model_path` (`./netmodels/LSTM`) is
    created on disk if it does not already exist.
  """
  layers = [
      # 22 units in each direction -> output_shape=(batch, 44)
      tf.keras.layers.Bidirectional(
          tf.keras.layers.LSTM(22), input_shape=(seq_length, 3)),
      # NOTE(review): sigmoid here vs. softmax in build_cnn, while training
      # uses sparse_categorical_crossentropy for both — confirm intended.
      tf.keras.layers.Dense(4, activation="sigmoid"),  # (batch, 4)
  ]
  model = tf.keras.Sequential(layers)
  path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  if not os.path.exists(path):
    os.makedirs(path)
  return model, path
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Loads the train/valid/test splits through `DataLoader`.

  Returns:
    A 6-tuple `(train_len, train_data, valid_len, valid_data, test_len,
    test_data)` read from the loader after `format()` has run.
  """
  loader = DataLoader(train_data_path, valid_data_path, test_data_path,
                      seq_length=seq_length)
  loader.format()
  return (
      loader.train_len, loader.train_data,
      loader.valid_len, loader.valid_data,
      loader.test_len, loader.test_data,
  )
def build_net(args, seq_length):
  """Builds the network selected by `args.model`.

  Args:
    args: Parsed command-line namespace; only `args.model` is read and it
      must be "CNN" or "LSTM".
    seq_length: Number of timesteps per input sample.

  Returns:
    The `(model, model_path)` tuple from `build_cnn` or `build_lstm`.

  Raises:
    ValueError: If `args.model` is not a known model name.
  """
  if args.model == "CNN":
    return build_cnn(seq_length)
  if args.model == "LSTM":
    return build_lstm(seq_length)
  print("Please input correct model name.(CNN LSTM)")
  # Fail explicitly; falling through to `return model, model_path` with
  # unbound locals raised an opaque UnboundLocalError.
  raise ValueError(
      "Unknown model name {!r}; expected 'CNN' or 'LSTM'.".format(args.model))
def _collect_labels(dataset):
  """Returns the labels of an unbatched `(data, label)` dataset as an array."""
  return np.array([label.numpy() for _, label in dataset])


def _save_tflite(model, path, optimizations=None):
  """Converts `model` to TensorFlow Lite, writes it to `path`, returns bytes."""
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  if optimizations:
    converter.optimizations = optimizations
  tflite_model = converter.convert()
  # `with` guarantees the file is flushed and closed before its size is read.
  with open(path, "wb") as f:
    f.write(tflite_model)
  return os.path.getsize(path)


def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,  # pylint: disable=unused-argument
    test_data,
    kind):
  """Trains the model, evaluates it, and exports TensorFlow Lite versions.

  Trains for 50 epochs of 1000 steps with batch size 64. It then prints the
  test-set confusion matrix, loss and accuracy. Finally it writes
  `model.tflite` (unquantized) and `model_quantized.tflite` to the working
  directory and prints their sizes.

  Args:
    model: Keras model returned by `build_net`.
    model_path: Unused.
    train_len: Unused; training runs a fixed number of steps per epoch over
      the repeated training dataset.
    train_data: Unbatched dataset of `(data, label)` pairs (presumably a
      `tf.data.Dataset`).
    valid_len: Number of validation samples; sizes `validation_steps`.
    valid_data: Unbatched validation dataset of `(data, label)` pairs.
    test_len: Unused; labels are read directly from `test_data`.
    test_data: Unbatched test dataset of `(data, label)` pairs.
    kind: Model name; for "CNN" every sample is reshaped to `[-1, 3, 1]`.
  """
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    # The CNN's input is (seq_length, 3, 1), so add a channel axis.
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)
  # Read ground-truth labels before batching, in dataset order, so they line
  # up one-to-one with the predictions computed below.
  test_labels = _collect_labels(test_data)
  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      # ceil(valid_len / batch_size): exactly one pass over validation data.
      validation_steps=int((valid_len + batch_size - 1) // batch_size),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))
  # Convert the model to the TensorFlow Lite format without quantization.
  basic_model_size = _save_tflite(model, "model.tflite")
  # Convert with quantization. Optimize.DEFAULT replaces the deprecated
  # OPTIMIZE_FOR_SIZE, which TensorFlow documents as doing the same thing.
  quantized_model_size = _save_tflite(
      model, "model_quantized.tflite", [tf.lite.Optimize.DEFAULT])
  print("Basic model is %d bytes" % basic_model_size)
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)
if __name__ == "__main__":
  # allow_abbrev=False rejects abbreviated option names. Otherwise, adding a
  # new option could silently change what an existing abbreviation means.
  parser = argparse.ArgumentParser(allow_abbrev=False)
  # Validate the model name up front, before the (slow) data loading step.
  parser.add_argument(
      "--model", "-m", required=True, choices=["CNN", "LSTM"],
      help="Network architecture to train.")
  parser.add_argument(
      "--person", "-p",
      help="Pass 'true' to use the per-person split under ./person_split; "
           "any other value uses ./data.")
  args = parser.parse_args()
  seq_length = 128
  print("Start to load data...")
  if args.person == "true":
    train_len, train_data, valid_len, valid_data, test_len, test_data = \
        load_data("./person_split/train", "./person_split/valid",
                  "./person_split/test", seq_length)
  else:
    train_len, train_data, valid_len, valid_data, test_len, test_data = \
        load_data("./data/train", "./data/valid", "./data/test", seq_length)
  print("Start to build net...")
  model, model_path = build_net(args, seq_length)
  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)
  print("Training finished!")