Disable argparse's automatic abbreviation of command-line arguments (allow_abbrev=False). Without this, adding a new argument can silently change script behaviour: code that wrongly relied on an abbreviated form of an existing argument may start matching the newly added argument instead. Signed-off-by: Jamie McCrae <jamie.mccrae@nordicsemi.no>
202 lines
7.4 KiB
Python
202 lines
7.4 KiB
Python
# Lint as: python3
|
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
# pylint: disable=g-bad-import-order
|
|
|
|
"""Build and train neural networks."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import datetime
|
|
import os # pylint: disable=duplicate-code
|
|
from data_load import DataLoader
|
|
|
|
import numpy as np # pylint: disable=duplicate-code
|
|
import tensorflow as tf
|
|
|
|
# Every training run logs its TensorBoard scalars under a fresh
# timestamped directory so runs can be compared side by side.
_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/scalars/" + _timestamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
|
|
|
|
|
|
def reshape_function(data, label):
  """Append a trailing channel axis so samples fit the CNN's 2-D input.

  Each accelerometer sequence is reshaped to (seq_length, 3, 1); the
  label is passed through unchanged.
  """
  return tf.reshape(data, [-1, 3, 1]), label
|
|
|
|
|
|
def calculate_model_size(model):
  """Print a model's summary and the total size of its trainable weights.

  Args:
    model: a Keras model (anything exposing ``summary()`` and
      ``trainable_variables`` whose elements have ``shape`` and
      ``dtype.size``).
  """
  print(model.summary())
  # np.prod replaces np.product, which was removed in NumPy 2.0.
  var_sizes = [
      np.prod(list(map(int, v.shape))) * v.dtype.size
      for v in model.trainable_variables
  ]
  print("Model size:", sum(var_sizes) / 1024, "KB")
|
|
|
|
|
|
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  Args:
    seq_length: number of time steps per input sequence.

  Returns:
    (model, model_path): the (possibly checkpoint-initialized) Keras model
    and the directory where CNN checkpoints are stored.
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  # Resume from saved weights only when a checkpoint exists; the original
  # unconditional load_weights call crashed on a fresh checkout.
  weights_path = os.path.join(model_path, "weights.h5")
  if os.path.exists(weights_path):
    model.load_weights(weights_path)
  return model, model_path
|
|
|
|
|
|
def build_lstm(seq_length):
  """Builds a bidirectional LSTM gesture classifier in Keras.

  Args:
    seq_length: number of time steps per input sequence.

  Returns:
    (model, model_path): the Keras model and the directory where LSTM
    checkpoints are stored.
  """
  recurrent = tf.keras.layers.Bidirectional(
      tf.keras.layers.LSTM(22),
      input_shape=(seq_length, 3))  # output_shape=(batch, 44)
  classifier = tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  model = tf.keras.Sequential([recurrent, classifier])
  model_path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  return model, model_path
|
|
|
|
|
|
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Load and format the train/validation/test splits.

  Returns:
    A 6-tuple: (train_len, train_data, valid_len, valid_data,
    test_len, test_data) as produced by DataLoader.
  """
  loader = DataLoader(
      train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
  loader.format()
  return (loader.train_len, loader.train_data,
          loader.valid_len, loader.valid_data,
          loader.test_len, loader.test_data)
|
|
|
|
|
|
def build_net(args, seq_length):
  """Builds the network selected by ``args.model``.

  Args:
    args: parsed command-line arguments; ``args.model`` selects the
      architecture and must be "CNN" or "LSTM".
    seq_length: number of time steps per input sequence.

  Returns:
    (model, model_path) from the matching builder.

  Raises:
    ValueError: if ``args.model`` names an unknown model. (The original
    printed a message and then crashed with UnboundLocalError on the
    return statement.)
  """
  if args.model == "CNN":
    model, model_path = build_cnn(seq_length)
  elif args.model == "LSTM":
    model, model_path = build_lstm(seq_length)
  else:
    raise ValueError(
        "Please input correct model name.(CNN LSTM); got %r" % args.model)
  return model, model_path
|
|
|
|
|
|
def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind):
  """Trains the model, evaluates it, and exports TensorFlow Lite models.

  Side effects: writes "model.tflite" and "model_quantized.tflite" to the
  current directory and logs training scalars via the module-level
  TensorBoard callback.

  Args:
    model: Keras model to compile and train.
    model_path: unused; kept for interface compatibility.
    train_len: unused; kept for interface compatibility.
    train_data: tf.data.Dataset of (data, label) training pairs.
    valid_len: number of validation examples; sets validation_steps.
    valid_data: tf.data.Dataset of validation pairs.
    test_len: number of test examples.
    test_data: tf.data.Dataset of test pairs.
    kind: "CNN" to reshape samples for the 2-D convolution; any other
      value (e.g. "LSTM") leaves samples as raw (seq, 3) sequences.
  """
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    # The CNN expects a trailing channel axis; see reshape_function.
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)
  # Collect test labels before batching so the confusion matrix below can
  # compare them against per-example predictions.
  test_labels = np.zeros(test_len)
  for idx, (data, label) in enumerate(test_data):  # pylint: disable=unused-variable
    test_labels[idx] = label.numpy()
  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))

  # Convert the model to the TensorFlow Lite format without quantization
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()

  # Save the model to disk; `with` closes the handle (the original left
  # the file object dangling from a bare open(...).write(...)).
  with open("model.tflite", "wb") as model_file:
    model_file.write(tflite_model)

  # Convert the model to the TensorFlow Lite format with quantization
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
  tflite_model = converter.convert()

  # Save the quantized model to disk
  with open("model_quantized.tflite", "wb") as model_file:
    model_file.write(tflite_model)

  basic_model_size = os.path.getsize("model.tflite")
  print("Basic model is %d bytes" % basic_model_size)
  quantized_model_size = os.path.getsize("model_quantized.tflite")
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)
|
|
|
|
|
|
if __name__ == "__main__":
  # allow_abbrev=False rejects abbreviated flags, so adding a new option
  # later cannot silently hijack a prefix of an existing one.
  parser = argparse.ArgumentParser(allow_abbrev=False)
  parser.add_argument("--model", "-m")
  parser.add_argument("--person", "-p")
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  # "--person true" selects the per-person split; anything else uses the
  # default random split.
  data_root = "./person_split" if args.person == "true" else "./data"
  train_len, train_data, valid_len, valid_data, test_len, test_data = \
      load_data(data_root + "/train", data_root + "/valid",
                data_root + "/test", seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)

  print("Training finished!")
|