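"""Training entry point: fits a denoising U-Net on noisy/clean audio pairs,
configured through Hydra and distributed over GPUs with
tf.distribute.MirroredStrategy."""
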
import logging
import os

import hydra

logger = logging.getLogger(__name__)


def run(args):
    import datetime

    import numpy as np
    import soundfile as sf
    import tensorflow as tf
    import tensorflow_addons as tfa
    from tensorflow.keras.optimizers import Adam
    from tqdm import tqdm

    import dataset_loader
    import unet

    path_experiment = str(args.path_experiment)

    if not os.path.exists(path_experiment):
        os.makedirs(path_experiment)

    path_music_train = args.dset.path_music_train
    path_music_test = args.dset.path_music_test
    path_music_validation = args.dset.path_music_validation

    path_noise = args.dset.path_noise
    path_recordings = args.dset.path_recordings

    fs = args.fs
    overlap = args.overlap
    seg_len_s_train = args.seg_len_s_train

    batch_size = args.batch_size
    epochs = args.epochs

    num_real_test_segments = args.num_real_test_segments
    buffer_size = args.buffer_size  # shuffle buffer size

    tensorboard_logs = args.tensorboard_logs
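
    # The Hydra config (conf/conf.yaml) is expected to define every key read in
    # this function: path_experiment, the dset.* paths, fs, overlap,
    # seg_len_s_train, batch_size, epochs, num_real_test_segments, buffer_size
    # and tensorboard_logs, plus stft.*, unet.*, lr, beta1, beta2, num_workers,
    # steps_per_epoch, use_tensorboard and variable_lr used further below.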

    def do_stft(noisy, clean=None):
        window_fn = tf.signal.hamming_window

        win_size = args.stft.win_size
        hop_size = args.stft.hop_size

        stft_signal_noisy = tf.signal.stft(noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
        stft_noisy_stacked = tf.stack(values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)

        if clean is not None:
            stft_signal_clean = tf.signal.stft(clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
            stft_clean_stacked = tf.stack(values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)

            return stft_noisy_stacked, stft_clean_stacked
        else:
            return stft_noisy_stacked
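
    # do_stft maps each waveform onto a real-valued tensor whose last axis has
    # size 2, holding the real and imaginary parts of the STFT; this stacked
    # representation is what the denoising model is trained on below.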

    # Loading data. The train dataset object is a generator; the validation dataset is loaded in memory.
    dataset_train, dataset_val = dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)

    dataset_train = dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
    dataset_val = dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)

    strategy = tf.distribute.MirroredStrategy()
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    with strategy.scope():
        # Build the model.
        unet_model = unet.build_model_denoise(unet_args=args.unet)

        current_lr = args.lr
        optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)

        loss = tf.keras.losses.MeanAbsoluteError()
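
    # Under MirroredStrategy, variables created inside the scope above are
    # replicated on each visible GPU and kept in sync; per-replica gradients
    # are reduced before the optimizer applies its update.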

    if args.use_tensorboard:
        log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment) + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
        train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
        val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")
        # Note: the training loop below writes summaries unconditionally, so it
        # relies on use_tensorboard being enabled.

    # Path where the checkpoints will be saved.
    checkpoint_filepath = os.path.join(path_experiment, 'checkpoint')

    dataset_train = dataset_train.batch(batch_size)
    dataset_val = dataset_val.batch(batch_size)

    # Prefetch for better input-pipeline performance.
    dataset_train = dataset_train.prefetch(batch_size * 20)
    dataset_val = dataset_val.prefetch(batch_size * 20)

    dataset_train = strategy.experimental_distribute_dataset(dataset_train)
    dataset_val = strategy.experimental_distribute_dataset(dataset_val)

    iterator = iter(dataset_train)
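
    # experimental_distribute_dataset splits each global batch of size
    # batch_size across the replicas; iterator.get_next() therefore yields
    # per-replica values, which trainer.distributed_training_step presumably
    # runs through strategy.run() (the Trainer class itself lives in trainer.py).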

    from trainer import Trainer

    trainer = Trainer(unet_model, optimizer, loss, strategy, path_experiment, args)

    for epoch in range(epochs):
        total_loss = 0
        step_loss = 0
        for step in tqdm(range(args.steps_per_epoch), desc="Training epoch " + str(epoch)):
            step_loss = trainer.distributed_training_step(iterator.get_next())
            total_loss += step_loss
            with train_summary_writer.as_default():
                tf.summary.scalar('batch_loss', step_loss, step=step)
                tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)

        train_loss = total_loss / args.steps_per_epoch

        for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
            trainer.distributed_test_step(x)

        template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
        print(template.format(epoch + 1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))

        with train_summary_writer.as_default():
            tf.summary.scalar('epoch_loss', train_loss, step=epoch)
            tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
        with val_summary_writer.as_default():
            tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
            tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)

        # Reset the running metrics before the next epoch.
        trainer.train_mae.reset_states()
        trainer.val_loss.reset_states()
        trainer.val_mae.reset_states()

        # Optional step decay: divide the learning rate by 10 every 50 epochs.
        if (epoch + 1) % 50 == 0:
            if args.variable_lr:
                current_lr *= 1e-1
                trainer.optimizer.lr = current_lr

        # Save the model weights at the end of every epoch.
        try:
            unet_model.save_weights(checkpoint_filepath)
        except Exception:
            pass


def _main(args):
    global __file__

    __file__ = hydra.utils.to_absolute_path(__file__)

    run(args)


@hydra.main(config_path="conf/conf.yaml")
def main(args):
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # Hydra intercepts the exit code; this is fixed in the beta, but I could not get the beta to work.
        os._exit(1)


if __name__ == "__main__":
    main()
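
# Example invocation (assuming this script is saved as train.py; any key defined
# in conf/conf.yaml can be overridden on the Hydra command line, and the values
# below are purely illustrative):
#   python train.py epochs=200 batch_size=8 path_experiment=experiments/run1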