import os
import logging

import hydra

logger = logging.getLogger(__name__)


def run(args):
    import datetime

    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam
    from tqdm import tqdm

    import dataset_loader
    import unet

    path_experiment = str(args.path_experiment)
    if not os.path.exists(path_experiment):
        os.makedirs(path_experiment)

    path_music_train = args.dset.path_music_train
    path_music_validation = args.dset.path_music_validation
    path_noise = args.dset.path_noise
    fs = args.fs
    seg_len_s_train = args.seg_len_s_train
    batch_size = args.batch_size
    epochs = args.epochs
    buffer_size = args.buffer_size  # for shuffle
    tensorboard_logs = args.tensorboard_logs

    def do_stft(noisy, clean=None):
        # Compute the STFT and stack real and imaginary parts as two channels.
        window_fn = tf.signal.hamming_window
        win_size = args.stft.win_size
        hop_size = args.stft.hop_size

        stft_signal_noisy = tf.signal.stft(
            noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
        )
        stft_noisy_stacked = tf.stack(
            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
            axis=-1,
        )

        if clean is not None:
            stft_signal_clean = tf.signal.stft(
                clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
            )
            stft_clean_stacked = tf.stack(
                values=[
                    tf.math.real(stft_signal_clean),
                    tf.math.imag(stft_signal_clean),
                ],
                axis=-1,
            )
            return stft_noisy_stacked, stft_clean_stacked
        else:
            return stft_noisy_stacked

    # Loading data. The train dataset object is a generator; the validation dataset is loaded in memory.
    dataset_train, dataset_val = dataset_loader.load_data(
        buffer_size,
        path_music_train,
        path_music_validation,
        path_noise,
        fs=fs,
        seg_len_s=seg_len_s_train,
    )

    dataset_train = dataset_train.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )
    dataset_val = dataset_val.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )

    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    with strategy.scope():
        # Build the model, optimizer and loss under the distribution strategy scope.
        unet_model = unet.build_model_denoise(unet_args=args.unet)

        current_lr = args.lr
        optimizer = Adam(
            learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2
        )
        loss = tf.keras.losses.MeanAbsoluteError()

    if args.use_tensorboard:
        log_dir = os.path.join(
            tensorboard_logs,
            os.path.basename(path_experiment)
            + "_"
            + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
        val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")
    else:
        # Fall back to no-op writers so the summary calls below remain valid
        # when TensorBoard logging is disabled.
        train_summary_writer = tf.summary.create_noop_writer()
        val_summary_writer = tf.summary.create_noop_writer()

    # Path where the checkpoints will be saved.
    checkpoint_filepath = os.path.join(path_experiment, "checkpoint")

    dataset_train = dataset_train.batch(batch_size)
    dataset_val = dataset_val.batch(batch_size)

    # Prefetching the dataset for better performance.
    dataset_train = dataset_train.prefetch(batch_size * 20)
    dataset_val = dataset_val.prefetch(batch_size * 20)

    dataset_train = strategy.experimental_distribute_dataset(dataset_train)
    dataset_val = strategy.experimental_distribute_dataset(dataset_val)

    iterator = iter(dataset_train)

    from trainer import Trainer

    trainer = Trainer(unet_model, optimizer, loss, strategy, path_experiment, args)

    for epoch in range(epochs):
        total_loss = 0
        step_loss = 0
        for step in tqdm(
            range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
        ):
            step_loss = trainer.distributed_training_step(iterator.get_next())
            total_loss += step_loss
            with train_summary_writer.as_default():
                tf.summary.scalar("batch_loss", step_loss, step=step)
                tf.summary.scalar(
                    "batch_mean_absolute_error", trainer.train_mae.result(), step=step
                )

        train_loss = total_loss / args.steps_per_epoch

        for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
            trainer.distributed_test_step(x)

        template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
        print(
            template.format(
                epoch + 1,
                train_loss,
                trainer.train_mae.result(),
                trainer.val_loss.result(),
                trainer.val_mae.result(),
            )
        )

        with train_summary_writer.as_default():
            tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            tf.summary.scalar(
                "epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
            )
        with val_summary_writer.as_default():
            tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
            tf.summary.scalar(
                "epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
            )

        trainer.train_mae.reset_states()
        trainer.val_loss.reset_states()
        trainer.val_mae.reset_states()

        # Decay the learning rate by a factor of 10 every 50 epochs if enabled.
        if (epoch + 1) % 50 == 0:
            if args.variable_lr:
                current_lr *= 1e-1
                trainer.optimizer.lr = current_lr

        try:
            unet_model.save_weights(checkpoint_filepath)
        except Exception:
            logger.warning("Could not save weights to %s", checkpoint_filepath)


def _main(args):
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    run(args)


@hydra.main(config_path="conf/conf.yaml")
def main(args):
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # Hydra intercepts the exit code; fixed in the beta, but I could not get the beta to work.
        os._exit(1)


if __name__ == "__main__":
    main()
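
# Usage sketch (assumptions: this script is saved as train.py, and conf/conf.yaml
# defines the keys read above, e.g. path_experiment, dset.*, stft.*, unet, lr,
# batch_size, epochs, steps_per_epoch and use_tensorboard):
#
#   python train.py batch_size=8 epochs=200 lr=1e-4
#
# Hydra merges the command-line overrides into the YAML config and passes the
# resolved object to main() as `args`.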