denoising-historical-data/test.py
2021-08-30 18:30:51 +03:00

152 lines
6.6 KiB
Python

import os
import hydra
import logging
'''
Script used for the objective experiments
WARNING: it calls MATLAB to calculate PEAQ and PEMO-Q. The whole process may be very slow
'''
logger = logging.getLogger(__name__)
def run(args):
import unet
import dataset_loader
import tensorflow as tf
import pandas as pd
path_experiment=str(args.path_experiment)
print(path_experiment)
if not os.path.exists(path_experiment):
os.makedirs(path_experiment)
unet_model = unet.build_model_denoise(stereo=stereo,unet_args=args.unet)
ckpt=os.path.join(path_experiment, 'checkpoint')
unet_model.load_weights(ckpt)
path_pianos_test=args.dset.path_piano_test
path_strings_test=args.dset.path_strings_test
path_orchestra_test=args.dset.path_orchestra_test
path_opera_test=args.dset.path_opera_test
path_noise=args.dset.path_noise
fs=args.fs
seg_len_s=20
numsamples=1000//seg_len_s
def do_stft(noisy, clean=None):
if args.stft.window=="hamming":
window_fn = tf.signal.hamming_window
elif args.stft.window=="hann":
window_fn=tf.signal.hann_window
elif args.stft.window=="kaiser_bessel":
window_fn=tf.signal.kaiser_bessel_derived_window
win_size=args.stft.win_size
hop_size=args.stft.hop_size
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
if clean!=None:
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
return stft_noisy_stacked, stft_clean_stacked
else:
return stft_noisy_stacked
from tester import Tester
testPath=os.path.join(path_experiment,"final_test")
if not os.path.exists(testPath):
os.makedirs(testPath)
tester=Tester(unet_model, testPath, args)
PEAQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PQevalAudio"
PEMOQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PEMOQ"
dataset_test_pianos=dataset_loader.load_data_formal( path_pianos_test, path_noise, noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_pianos=dataset_test_pianos.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_pianos,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("pianos_midsnr")
dataset_test_strings=dataset_loader.load_data_formal( path_strings_test, path_noise,noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_strings=dataset_test_strings.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_strings,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("strings_midsnr")
dataset_test_orchestra=dataset_loader.load_data_formal( path_orchestra_test, path_noise, noise_amount="mid_snr", num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_orchestra=dataset_test_orchestra.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_orchestra,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("orchestra_midsnr")
dataset_test_opera=dataset_loader.load_data_formal( path_opera_test, path_noise, noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_opera=dataset_test_opera.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_opera,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("opera_midsnr")
dataset_test_strings=dataset_loader.load_data_formal( path_strings_test, path_noise,noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_strings=dataset_test_strings.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_strings,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("strings_lowsnr")
dataset_test_orchestra=dataset_loader.load_data_formal( path_orchestra_test, path_noise,noise_amount="low_snr", num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_orchestra=dataset_test_orchestra.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_orchestra,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("orchestra_lowsnr")
dataset_test_opera=dataset_loader.load_data_formal( path_opera_test, path_noise, noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_opera=dataset_test_opera.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_opera,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("opera_lowsnr")
dataset_test_pianos=dataset_loader.load_data_formal( path_pianos_test, path_noise, noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset_test_pianos=dataset_test_pianos.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset_test_pianos,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference("pianos_lowsnr")
names=["strings_midsnr","strings_lowsnr","opera_midsnr","opera_lowsnr","pianos_midsnr","pianos_lowsnr","orchestra_midsnr","orchestra_lowsnr"]
for n in names:
a=pd.read_csv(os.path.join(testPath,n,"metrics.csv"))
meanPEAQ=a["PEAQ(ODG)_diff"].sum()/50
meanPEMOQ=a["PEMOQ(ODG)_diff"].sum()/50
meanSDR=a["SDR_diff"].sum()/50
print(n,": PEAQ ",str(meanPEAQ), "PEMOQ ", str(meanPEMOQ), "SDR ", str(meanSDR))
def _main(args):
global __file__
__file__ = hydra.utils.to_absolute_path(__file__)
run(args)
@hydra.main(config_path="conf/conf.yaml")
def main(args):
try:
_main(args)
except Exception:
logger.exception("Some error happened")
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
os._exit(1)
if __name__ == "__main__":
main()