# denoising-historical-data/train.py

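# Training script for the historical-recording denoising U-Net: builds the tf.data
# pipelines and the STFT front-end, then runs a custom multi-GPU training loop
# (tf.distribute.MirroredStrategy) driven by a Hydra config.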
import os
import hydra
import logging
logger = logging.getLogger(__name__)


def run(args):
    import unet
    import tensorflow as tf
    import dataset_loader
    from tensorflow.keras.optimizers import Adam
    import datetime
    from tqdm import tqdm

    path_experiment = str(args.path_experiment)
    if not os.path.exists(path_experiment):
        os.makedirs(path_experiment)

    path_music_train = args.dset.path_music_train
    path_music_validation = args.dset.path_music_validation
    path_noise = args.dset.path_noise

    fs = args.fs
    seg_len_s_train = args.seg_len_s_train
    batch_size = args.batch_size
    epochs = args.epochs
    buffer_size = args.buffer_size  # for shuffle
    tensorboard_logs = args.tensorboard_logs
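
    # do_stft maps waveforms to the network's input/target representation: an STFT with a
    # Hamming window whose real and imaginary parts are stacked as two channels along the
    # last axis. Window and hop sizes come from the Hydra config (args.stft).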
    def do_stft(noisy, clean=None):
        window_fn = tf.signal.hamming_window
        win_size = args.stft.win_size
        hop_size = args.stft.hop_size

        stft_signal_noisy = tf.signal.stft(
            noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
        )
        stft_noisy_stacked = tf.stack(
            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
            axis=-1,
        )

        if clean is not None:
            stft_signal_clean = tf.signal.stft(
                clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
            )
            stft_clean_stacked = tf.stack(
                values=[
                    tf.math.real(stft_signal_clean),
                    tf.math.imag(stft_signal_clean),
                ],
                axis=-1,
            )
            return stft_noisy_stacked, stft_clean_stacked
        else:
            return stft_noisy_stacked
    # Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
    dataset_train, dataset_val = dataset_loader.load_data(
        buffer_size,
        path_music_train,
        path_music_validation,
        path_noise,
        fs=fs,
        seg_len_s=seg_len_s_train,
    )

    dataset_train = dataset_train.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )
    dataset_val = dataset_val.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
    with strategy.scope():
        # build the model
        unet_model = unet.build_model_denoise(unet_args=args.unet)

        current_lr = args.lr
        optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)
        loss = tf.keras.losses.MeanAbsoluteError()

    if args.use_tensorboard:
        log_dir = os.path.join(
            tensorboard_logs,
            os.path.basename(path_experiment)
            + "_"
            + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
        val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")
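        # Logs land in <tensorboard_logs>/<experiment>_<timestamp>/{train,validation}, so
        # the train and validation curves can be overlaid in TensorBoard.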

    # path where the checkpoints will be saved
    checkpoint_filepath = os.path.join(path_experiment, "checkpoint")

    dataset_train = dataset_train.batch(batch_size)
    dataset_val = dataset_val.batch(batch_size)
    # prefetching the dataset for better performance
    dataset_train = dataset_train.prefetch(batch_size * 20)
    dataset_val = dataset_val.prefetch(batch_size * 20)

    dataset_train = strategy.experimental_distribute_dataset(dataset_train)
    dataset_val = strategy.experimental_distribute_dataset(dataset_val)
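    # With distributed datasets, batch_size is the global batch size: each batch is split
    # evenly across the replicas managed by the strategy.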

    iterator = iter(dataset_train)
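    # An explicit iterator is used because each epoch draws exactly args.steps_per_epoch
    # batches via iterator.get_next(), rather than iterating the training set to exhaustion.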

    from trainer import Trainer

    trainer = Trainer(unet_model, optimizer, loss, strategy, path_experiment, args)
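    # Trainer (trainer.py) is assumed to wrap the per-replica train/test steps, calling
    # strategy.run internally and reducing the per-replica losses and MAE metrics.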

    for epoch in range(epochs):
        total_loss = 0
        step_loss = 0
        for step in tqdm(
            range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
        ):
            step_loss = trainer.distributed_training_step(iterator.get_next())
            total_loss += step_loss
            # the summary writers only exist when TensorBoard logging is enabled
            if args.use_tensorboard:
                with train_summary_writer.as_default():
                    tf.summary.scalar("batch_loss", step_loss, step=step)
                    tf.summary.scalar(
                        "batch_mean_absolute_error",
                        trainer.train_mae.result(),
                        step=step,
                    )

        train_loss = total_loss / args.steps_per_epoch

        for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
            trainer.distributed_test_step(x)

        template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
        print(
            template.format(
                epoch + 1,
                train_loss,
                trainer.train_mae.result(),
                trainer.val_loss.result(),
                trainer.val_mae.result(),
            )
        )

        if args.use_tensorboard:
            with train_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", train_loss, step=epoch)
                tf.summary.scalar(
                    "epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
                )
            with val_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
                tf.summary.scalar(
                    "epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
                )
        trainer.train_mae.reset_states()
        trainer.val_loss.reset_states()
        trainer.val_mae.reset_states()
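
        # Optional step decay: divide the learning rate by 10 every 50 epochs.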
        if (epoch + 1) % 50 == 0:
            if args.variable_lr:
                current_lr *= 1e-1
                trainer.optimizer.lr = current_lr
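
        # The same checkpoint is overwritten every epoch; save errors are swallowed so a
        # transient write failure does not abort a long training run.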
        try:
            unet_model.save_weights(checkpoint_filepath)
        except Exception:
            pass
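

# Hydra changes the working directory at runtime, so __file__ is resolved to an absolute
# path (relative to the original working directory) before training starts.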
def _main(args):
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    run(args)


@hydra.main(config_path="conf/conf.yaml")
def main(args):
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # Hydra intercepts the exit code; this is fixed in a beta release, but I could not get the beta to work
        os._exit(1)


if __name__ == "__main__":
    main()