startin again
This commit is contained in:
parent
fd35cee560
commit
c4cbdd2b8a
497
dataset_loader.py
Normal file
497
dataset_loader.py
Normal file
@ -0,0 +1,497 @@
|
||||
from typing import Tuple, Dict
|
||||
import ast
|
||||
|
||||
import tensorflow as tf
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
from scipy.fft import fft, ifft
|
||||
import soundfile as sf
|
||||
import librosa
|
||||
import math
|
||||
import pandas as pd
|
||||
import scipy as sp
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
|
||||
#generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis.
|
||||
def __noise_sample_generator(info_file,fs, length_seq, split):
|
||||
head=os.path.split(info_file)[0]
|
||||
load_data=pd.read_csv(info_file)
|
||||
#split= train, validation, test
|
||||
load_data_split=load_data.loc[load_data["split"]==split]
|
||||
load_data_split=load_data_split.reset_index(drop=True)
|
||||
while True:
|
||||
r = list(range(len(load_data_split)))
|
||||
if split!="test":
|
||||
random.shuffle(r)
|
||||
for i in r:
|
||||
segments=ast.literal_eval(load_data_split.loc[i,"segments"])
|
||||
if split=="test":
|
||||
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
|
||||
else:
|
||||
num=np.random.randint(0,len(segments))
|
||||
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
|
||||
|
||||
if fs!=Fs:
|
||||
print("wrong fs, resampling...")
|
||||
data=librosa.resample(loaded_data, Fs, fs)
|
||||
|
||||
yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
|
||||
|
||||
def __extend_sample_by_repeating(data, fs,seq_len):
|
||||
rpm=78
|
||||
target_samp=seq_len
|
||||
large_data=np.zeros(shape=(target_samp,2))
|
||||
|
||||
if len(data)>=target_samp:
|
||||
large_data=data[0:target_samp]
|
||||
return large_data
|
||||
|
||||
bls=(1000*44100)/1000 #hardcoded
|
||||
|
||||
window=np.stack((np.hanning(bls) ,np.hanning(bls)), axis=1)
|
||||
window_left=window[0:int(bls/2),:]
|
||||
window_right=window[int(bls/2)::,:]
|
||||
bls=int(bls/2)
|
||||
|
||||
rps=rpm/60
|
||||
period=1/rps
|
||||
|
||||
period_sam=int(period*fs)
|
||||
|
||||
overhead=len(data)%period_sam
|
||||
|
||||
if(overhead>bls):
|
||||
complete_periods=(len(data)//period_sam)*period_sam
|
||||
else:
|
||||
complete_periods=(len(data)//period_sam -1)*period_sam
|
||||
|
||||
|
||||
a=np.multiply(data[0:bls], window_left)
|
||||
b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
|
||||
c_1=np.concatenate((data[0:complete_periods,:],b))
|
||||
c_2=np.concatenate((a,data[bls:complete_periods,:],b))
|
||||
c_3=np.concatenate((a,data[bls::,:]))
|
||||
|
||||
large_data[0:complete_periods+bls,:]=c_1
|
||||
|
||||
|
||||
pointer=complete_periods
|
||||
not_finished=True
|
||||
while (not_finished):
|
||||
if target_samp>pointer+complete_periods+bls:
|
||||
large_data[pointer:pointer+complete_periods+bls] +=c_2
|
||||
pointer+=complete_periods
|
||||
else:
|
||||
large_data[pointer::]+=c_3[0:(target_samp-pointer)]
|
||||
#finish
|
||||
not_finished=False
|
||||
|
||||
return large_data
|
||||
|
||||
|
||||
def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
|
||||
|
||||
records_info=os.path.join(path_recordings,"audio_files.txt")
|
||||
num_lines = sum(1 for line in open(records_info))
|
||||
f = open(records_info,"r")
|
||||
#load data record files
|
||||
print("Loading record files")
|
||||
records=[]
|
||||
seg_len=fs*seg_len_s
|
||||
pointer=int(fs*5) #starting at second 5 by default
|
||||
for i in tqdm(range(num_lines)):
|
||||
audio=f.readline()
|
||||
audio=audio[:-1]
|
||||
data, fs=sf.read(os.path.join(path_recordings,audio))
|
||||
if len(data.shape)>1 and not(stereo):
|
||||
data=np.mean(data,axis=1)
|
||||
#elif stereo and len(data.shape)==1:
|
||||
# data=np.stack((data, data), axis=1)
|
||||
|
||||
#normalize
|
||||
data=data/np.max(np.abs(data))
|
||||
segment=data[pointer:pointer+seg_len]
|
||||
records.append(segment.astype("float32"))
|
||||
|
||||
return records
|
||||
|
||||
def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False):
|
||||
|
||||
print(num_samples)
|
||||
segments_clean=[]
|
||||
segments_noisy=[]
|
||||
seg_len=fs*seg_len_s
|
||||
noises_info=os.path.join(path_noises,"info.csv")
|
||||
np.random.seed(42)
|
||||
if noise_amount=="low_snr":
|
||||
SNRs=np.random.uniform(2,6,num_samples)
|
||||
elif noise_amount=="mid_snr":
|
||||
SNRs=np.random.uniform(6,12,num_samples)
|
||||
|
||||
scales=np.random.uniform(-4,0,num_samples)
|
||||
#SNRs=[2,6,12] #HARDCODED!!!!
|
||||
i=0
|
||||
print(path_pianos[0])
|
||||
print(seg_len)
|
||||
train_samples=glob.glob(os.path.join(path_pianos[0],"*.wav"))
|
||||
train_samples=sorted(train_samples)
|
||||
|
||||
if prenoise:
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, extend, "test") #Adds 1s of silence add the begiing, longer noise
|
||||
else:
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, extend, "test") #this will take care of everything
|
||||
#load data clean files
|
||||
for file in tqdm(train_samples): #add [1:5] for testing
|
||||
data_clean, samplerate = sf.read(file)
|
||||
if samplerate!=fs:
|
||||
print("!!!!WRONG SAMPLE RATe!!!")
|
||||
#Stereo to mono
|
||||
if len(data_clean.shape)>1 and not(stereo):
|
||||
data_clean=np.mean(data_clean,axis=1)
|
||||
#elif stereo and len(data_clean.shape)==1:
|
||||
# data_clean=np.stack((data_clean, data_clean), axis=1)
|
||||
#normalize
|
||||
data_clean=data_clean/np.max(np.abs(data_clean))
|
||||
#data_clean_loaded.append(data_clean)
|
||||
|
||||
#framify data clean files
|
||||
|
||||
#framify arguments: seg_len, hop_size
|
||||
hop_size=int(seg_len)# no overlap
|
||||
|
||||
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
||||
print(num_frames)
|
||||
if num_frames==0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
num_frames=1
|
||||
|
||||
data_not_finished=True
|
||||
pointer=0
|
||||
while(data_not_finished):
|
||||
if i>=num_samples:
|
||||
break
|
||||
segment=data_clean[pointer:pointer+seg_len]
|
||||
pointer=pointer+hop_size
|
||||
if pointer+seg_len>len(data_clean):
|
||||
data_not_finished=False
|
||||
segment=segment.astype('float32')
|
||||
|
||||
#SNRs=np.random.uniform(2,20)
|
||||
snr=SNRs[i]
|
||||
scale=scales[i]
|
||||
#load noise signal
|
||||
data_noise= next(noise_generator)
|
||||
data_noise=np.mean(data_noise,axis=1)
|
||||
#normalize
|
||||
data_noise=data_noise/np.max(np.abs(data_noise))
|
||||
new_noise=data_noise #if more processing needed, add here
|
||||
#load clean data
|
||||
#configure sizes
|
||||
power_clean=np.var(segment)
|
||||
#estimate noise power
|
||||
if prenoise:
|
||||
power_noise=np.var(new_noise[fs::])
|
||||
else:
|
||||
power_noise=np.var(new_noise)
|
||||
|
||||
snr = 10.0**(snr/10.0)
|
||||
|
||||
#sum both signals according to snr
|
||||
if prenoise:
|
||||
segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
|
||||
summed=summed.astype('float32')
|
||||
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
|
||||
summed=10.0**(scale/10.0) *summed
|
||||
segment=10.0**(scale/10.0) *segment
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
segments_clean.append(segment.astype('float32'))
|
||||
i=i+1
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
|
||||
|
||||
segments_clean=[]
|
||||
segments_noisy=[]
|
||||
seg_len=fs*seg_len_s
|
||||
noises_info=os.path.join(path_noises,"info.csv")
|
||||
SNRs=[2,6,12] #HARDCODED!!!!
|
||||
for path in path_music:
|
||||
print(path)
|
||||
train_samples=glob.glob(os.path.join(path,"*.wav"))
|
||||
train_samples=sorted(train_samples)
|
||||
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
|
||||
#load data clean files
|
||||
jj=0
|
||||
for file in tqdm(train_samples): #add [1:5] for testing
|
||||
data_clean, samplerate = sf.read(file)
|
||||
if samplerate!=fs:
|
||||
print("!!!!WRONG SAMPLE RATe!!!")
|
||||
#Stereo to mono
|
||||
if len(data_clean.shape)>1:
|
||||
data_clean=np.mean(data_clean,axis=1)
|
||||
#normalize
|
||||
data_clean=data_clean/np.max(np.abs(data_clean))
|
||||
#data_clean_loaded.append(data_clean)
|
||||
|
||||
#framify data clean files
|
||||
|
||||
#framify arguments: seg_len, hop_size
|
||||
hop_size=int(seg_len)# no overlap
|
||||
|
||||
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
||||
if num_frames==0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
num_frames=1
|
||||
|
||||
pointer=0
|
||||
segment=data_clean[pointer:pointer+(seg_len-2*fs)]
|
||||
segment=segment.astype('float32')
|
||||
segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #I hope its ok
|
||||
#segments_clean.append(segment)
|
||||
|
||||
for snr in SNRs:
|
||||
#load noise signal
|
||||
data_noise= next(noise_generator)
|
||||
data_noise=np.mean(data_noise,axis=1)
|
||||
#normalize
|
||||
data_noise=data_noise/np.max(np.abs(data_noise))
|
||||
new_noise=data_noise #if more processing needed, add here
|
||||
#load clean data
|
||||
#configure sizes
|
||||
#estimate clean signal power
|
||||
power_clean=np.var(segment)
|
||||
#estimate noise power
|
||||
power_noise=np.var(new_noise)
|
||||
|
||||
snr = 10.0**(snr/10.0)
|
||||
|
||||
#sum both signals according to snr
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
summed=summed.astype('float32')
|
||||
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
segments_clean.append(segment.astype('float32'))
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5):
|
||||
|
||||
val_samples=[]
|
||||
for path in path_music:
|
||||
val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
|
||||
|
||||
#load data clean files
|
||||
print("Loading clean files")
|
||||
data_clean_loaded=[]
|
||||
for ff in tqdm(range(0,len(val_samples))): #add [1:5] for testing
|
||||
data_clean, samplerate = sf.read(val_samples[ff])
|
||||
if samplerate!=fs:
|
||||
print("!!!!WRONG SAMPLE RATe!!!")
|
||||
#Stereo to mono
|
||||
if len(data_clean.shape)>1 :
|
||||
data_clean=np.mean(data_clean,axis=1)
|
||||
#normalize
|
||||
data_clean=data_clean/np.max(np.abs(data_clean))
|
||||
data_clean_loaded.append(data_clean)
|
||||
del data_clean
|
||||
|
||||
#framify data clean files
|
||||
print("Framifying clean files")
|
||||
seg_len=fs*seg_len_s
|
||||
segments_clean=[]
|
||||
for file in tqdm(data_clean_loaded):
|
||||
|
||||
#framify arguments: seg_len, hop_size
|
||||
hop_size=int(seg_len)# no overlap
|
||||
|
||||
num_frames=np.floor(len(file)/hop_size - seg_len/hop_size +1)
|
||||
pointer=0
|
||||
for i in range(0,int(num_frames)):
|
||||
segment=file[pointer:pointer+seg_len]
|
||||
pointer=pointer+hop_size
|
||||
segment=segment.astype('float32')
|
||||
segments_clean.append(segment)
|
||||
|
||||
del data_clean_loaded
|
||||
|
||||
SNRs=np.random.uniform(2,20,len(segments_clean))
|
||||
scales=np.random.uniform(-6,4,len(segments_clean))
|
||||
#noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
|
||||
noises_info=os.path.join(path_noises,"info.csv")
|
||||
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything
|
||||
|
||||
|
||||
#generate noisy segments
|
||||
#load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
|
||||
|
||||
#noise_samples=glob.glob(os.path.join(path_noises,"*.wav"))
|
||||
segments_noisy=[]
|
||||
print("Processing noisy segments")
|
||||
|
||||
for i in tqdm(range(0,len(segments_clean))):
|
||||
#load noise signal
|
||||
data_noise= next(noise_generator)
|
||||
#Stereo to mono
|
||||
data_noise=np.mean(data_noise,axis=1)
|
||||
#normalize
|
||||
data_noise=data_noise/np.max(np.abs(data_noise))
|
||||
new_noise=data_noise #if more processing needed, add here
|
||||
#load clean data
|
||||
data_clean=segments_clean[i]
|
||||
#configure sizes
|
||||
|
||||
|
||||
#estimate clean signal power
|
||||
power_clean=np.var(data_clean)
|
||||
#estimate noise power
|
||||
power_noise=np.var(new_noise)
|
||||
|
||||
snr = 10.0**(SNRs[i]/10.0)
|
||||
|
||||
#sum both signals according to snr
|
||||
summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
#the rest is normal
|
||||
|
||||
summed=10.0**(scales[i]/10.0) *summed
|
||||
segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i]
|
||||
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
|
||||
|
||||
def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False):
|
||||
|
||||
train_samples=[]
|
||||
for path in path_music:
|
||||
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav")))
|
||||
|
||||
seg_len=fs*seg_len_s
|
||||
noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything
|
||||
#load data clean files
|
||||
while True:
|
||||
random.shuffle(train_samples)
|
||||
for file in train_samples:
|
||||
data, samplerate = sf.read(file)
|
||||
if samplerate!=fs:
|
||||
print("!!!!WRONG SAMPLE RATe!!!")
|
||||
data=np.transpose(data)
|
||||
data=librosa.resample(data, samplerate, 44100)
|
||||
data=np.transpose(data)
|
||||
data_clean=data
|
||||
#Stereo to mono
|
||||
if len(data.shape)>1 :
|
||||
data_clean=np.mean(data_clean,axis=1)
|
||||
|
||||
#normalize
|
||||
data_clean=data_clean/np.max(np.abs(data_clean))
|
||||
|
||||
#framify data clean files
|
||||
|
||||
#framify arguments: seg_len, hop_size
|
||||
hop_size=int(seg_len)
|
||||
|
||||
num_frames=np.floor(len(data_clean)/seg_len)
|
||||
if num_frames==0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
num_frames=1
|
||||
pointer=0
|
||||
data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation
|
||||
elif num_frames>1:
|
||||
pointer=np.random.randint(0,hop_size) #initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
|
||||
else:
|
||||
pointer=0
|
||||
|
||||
data_not_finished=True
|
||||
while(data_not_finished):
|
||||
segment=data_clean[pointer:pointer+seg_len]
|
||||
pointer=pointer+hop_size
|
||||
if pointer+seg_len>len(data_clean):
|
||||
data_not_finished=False
|
||||
segment=segment.astype('float32')
|
||||
|
||||
SNRs=np.random.uniform(2,20)
|
||||
scale=np.random.uniform(-6,4)
|
||||
|
||||
|
||||
#load noise signal
|
||||
data_noise= next(noise_generator)
|
||||
data_noise=np.mean(data_noise,axis=1)
|
||||
#normalize
|
||||
data_noise=data_noise/np.max(np.abs(data_noise))
|
||||
new_noise=data_noise #if more processing needed, add here
|
||||
#load clean data
|
||||
#configure sizes
|
||||
if stereo:
|
||||
#estimate clean signal power
|
||||
power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
|
||||
#estimate noise power
|
||||
power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
|
||||
else:
|
||||
#estimate clean signal power
|
||||
power_clean=np.var(segment)
|
||||
#estimate noise power
|
||||
power_noise=np.var(new_noise)
|
||||
|
||||
snr = 10.0**(SNRs/10.0)
|
||||
|
||||
|
||||
#sum both signals according to snr
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
summed=10.0**(scale/10.0) *summed
|
||||
segment=10.0**(scale/10.0) *segment
|
||||
|
||||
summed=summed.astype('float32')
|
||||
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
||||
print("Generating train dataset")
|
||||
trainshape=int(fs*seg_len_s)
|
||||
|
||||
dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) )
|
||||
|
||||
|
||||
print("Generating validation dataset")
|
||||
segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s)
|
||||
|
||||
dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
|
||||
return dataset_train.shuffle(buffer_size), dataset_val
|
||||
|
||||
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs) -> Tuple[tf.data.Dataset]:
|
||||
print("Generating test dataset")
|
||||
segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
|
||||
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||
return dataset_test
|
||||
def load_data_formal( path_pianos_test, path_noises, **kwargs) -> Tuple[tf.data.Dataset]:
|
||||
print("Generating test dataset")
|
||||
segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||
print("segments::")
|
||||
print(len(segments_noisy))
|
||||
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
|
||||
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||
return dataset_test
|
||||
|
||||
def load_real_test_recordings(buffer_size, path_recordings, **kwargs) -> Tuple[tf.data.Dataset]:
|
||||
print("Generating real test dataset")
|
||||
|
||||
segments_noisy=generate_real_recordings_data(path_recordings, **kwargs)
|
||||
|
||||
dataset_test=tf.data.Dataset.from_tensor_slices(segments_noisy)
|
||||
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||
return dataset_test
|
||||
164
inference.py
Normal file
164
inference.py
Normal file
@ -0,0 +1,164 @@
|
||||
import os
|
||||
import hydra
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import tensorflow as tf
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import librosa
|
||||
|
||||
path_experiment=str(args.path_experiment)
|
||||
|
||||
unet_model = unet.build_model_denoise(unet_args=args.unet)
|
||||
|
||||
ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint')
|
||||
unet_model.load_weights(ckpt)
|
||||
|
||||
def do_stft(noisy):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size=args.stft.win_size
|
||||
hop_size=args.stft.hop_size
|
||||
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
def do_istft(data):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size=args.stft.win_size
|
||||
hop_size=args.stft.hop_size
|
||||
|
||||
inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
|
||||
|
||||
pred_cpx=data[...,0] + 1j * data[...,1]
|
||||
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
|
||||
return pred_time
|
||||
|
||||
audio=str(args.inference.audio)
|
||||
data, samplerate = sf.read(audio)
|
||||
if samplerate!=44100:
|
||||
print("Resampling")
|
||||
data=np.transpose(data)
|
||||
data=librosa.resample(data, samplerate, 44100)
|
||||
data=np.transpose(data)
|
||||
|
||||
#Stereo to mono
|
||||
if len(data.shape)>1:
|
||||
data=np.mean(data,axis=1)
|
||||
|
||||
|
||||
segment_size=44101*20 #20s segments
|
||||
|
||||
length_data=len(data)
|
||||
overlapsize=2048 #samples (46 ms)
|
||||
window=np.hanning(2*overlapsize)
|
||||
window_right=window[overlapsize::]
|
||||
window_left=window[0:overlapsize]
|
||||
audio_finished=False
|
||||
pointer=0
|
||||
denoised_data=np.zeros(shape=(len(data),))
|
||||
residual_noise=np.zeros(shape=(len(data),))
|
||||
numchunks=int(np.ceil(length_data/segment_size))
|
||||
|
||||
for i in tqdm(range(numchunks)):
|
||||
if pointer+segment_size<length_data:
|
||||
segment=data[pointer:pointer+segment_size]
|
||||
#dostft
|
||||
segment_TF=do_stft(segment)
|
||||
segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
|
||||
pred = unet_model.predict(segment_TF_ds.batch(1))
|
||||
pred=pred[0]
|
||||
residual=segment_TF-pred[0]
|
||||
residual=np.array(residual)
|
||||
pred_time=do_istft(pred[0])
|
||||
residual_time=do_istft(residual)
|
||||
residual_time=np.array(residual_time)
|
||||
|
||||
if pointer==0:
|
||||
pred_time=np.concatenate((pred_time[0:int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||
residual_time=np.concatenate((residual_time[0:int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||
else:
|
||||
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||
|
||||
denoised_data[pointer:pointer+segment_size]=denoised_data[pointer:pointer+segment_size]+pred_time
|
||||
residual_noise[pointer:pointer+segment_size]=residual_noise[pointer:pointer+segment_size]+residual_time
|
||||
|
||||
pointer=pointer+segment_size-overlapsize
|
||||
else:
|
||||
segment=data[pointer::]
|
||||
lensegment=len(segment)
|
||||
segment=np.concatenate((segment, np.zeros(shape=(int(segment_size-len(segment)),))), axis=0)
|
||||
audio_finished=True
|
||||
#dostft
|
||||
segment_TF=do_stft(segment)
|
||||
|
||||
segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
|
||||
|
||||
pred = unet_model.predict(segment_TF_ds.batch(1))
|
||||
pred=pred[0]
|
||||
residual=segment_TF-pred[0]
|
||||
residual=np.array(residual)
|
||||
pred_time=do_istft(pred[0])
|
||||
pred_time=np.array(pred_time)
|
||||
pred_time=pred_time[0:segment_size]
|
||||
residual_time=do_istft(residual)
|
||||
residual_time=np.array(residual_time)
|
||||
residual_time=residual_time[0:segment_size]
|
||||
if pointer==0:
|
||||
pred_time=pred_time
|
||||
residual_time=residual_time
|
||||
else:
|
||||
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||
|
||||
denoised_data[pointer::]=denoised_data[pointer::]+pred_time[0:lensegment]
|
||||
residual_noise[pointer::]=residual_noise[pointer::]+residual_time[0:lensegment]
|
||||
|
||||
basename=os.path.splitext(audio)[0]
|
||||
wav_noisy_name=basename+"_noisy_input"+".wav"
|
||||
sf.write(wav_noisy_name, data, 44100)
|
||||
wav_output_name=basename+"_denoised"+".wav"
|
||||
sf.write(wav_output_name, denoised_data, 44100)
|
||||
wav_output_name=basename+"_residual"+".wav"
|
||||
sf.write(wav_output_name, residual_noise, 44100)
|
||||
|
||||
|
||||
def _main(args):
|
||||
global __file__
|
||||
|
||||
__file__ = hydra.utils.to_absolute_path(__file__)
|
||||
|
||||
run(args)
|
||||
|
||||
|
||||
@hydra.main(config_path="conf/conf.yaml")
|
||||
def main(args):
|
||||
try:
|
||||
_main(args)
|
||||
except Exception:
|
||||
logger.exception("Some error happened")
|
||||
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
|
||||
os._exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
151
test.py
Normal file
151
test.py
Normal file
@ -0,0 +1,151 @@
|
||||
import os
|
||||
import hydra
|
||||
import logging
|
||||
'''
|
||||
Script used for the objective experiments
|
||||
WARNING: it calls MATLAB to calculate PEAQ and PEMO-Q. The whole process may be very slow
|
||||
'''
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import dataset_loader
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
|
||||
path_experiment=str(args.path_experiment)
|
||||
|
||||
print(path_experiment)
|
||||
if not os.path.exists(path_experiment):
|
||||
os.makedirs(path_experiment)
|
||||
|
||||
unet_model = unet.build_model_denoise(stereo=stereo,unet_args=args.unet)
|
||||
|
||||
ckpt=os.path.join(path_experiment, 'checkpoint')
|
||||
unet_model.load_weights(ckpt)
|
||||
|
||||
|
||||
path_pianos_test=args.dset.path_piano_test
|
||||
path_strings_test=args.dset.path_strings_test
|
||||
path_orchestra_test=args.dset.path_orchestra_test
|
||||
path_opera_test=args.dset.path_opera_test
|
||||
path_noise=args.dset.path_noise
|
||||
fs=args.fs
|
||||
seg_len_s=20
|
||||
numsamples=1000//seg_len_s
|
||||
|
||||
def do_stft(noisy, clean=None):
|
||||
|
||||
if args.stft.window=="hamming":
|
||||
window_fn = tf.signal.hamming_window
|
||||
elif args.stft.window=="hann":
|
||||
window_fn=tf.signal.hann_window
|
||||
elif args.stft.window=="kaiser_bessel":
|
||||
window_fn=tf.signal.kaiser_bessel_derived_window
|
||||
|
||||
win_size=args.stft.win_size
|
||||
hop_size=args.stft.hop_size
|
||||
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
|
||||
if clean!=None:
|
||||
|
||||
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
|
||||
|
||||
|
||||
return stft_noisy_stacked, stft_clean_stacked
|
||||
else:
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
from tester import Tester
|
||||
|
||||
testPath=os.path.join(path_experiment,"final_test")
|
||||
if not os.path.exists(testPath):
|
||||
os.makedirs(testPath)
|
||||
|
||||
tester=Tester(unet_model, testPath, args)
|
||||
|
||||
PEAQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PQevalAudio"
|
||||
PEMOQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PEMOQ"
|
||||
|
||||
dataset_test_pianos=dataset_loader.load_data_formal( path_pianos_test, path_noise, noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_pianos=dataset_test_pianos.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_pianos,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("pianos_midsnr")
|
||||
|
||||
dataset_test_strings=dataset_loader.load_data_formal( path_strings_test, path_noise,noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_strings=dataset_test_strings.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_strings,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("strings_midsnr")
|
||||
|
||||
dataset_test_orchestra=dataset_loader.load_data_formal( path_orchestra_test, path_noise, noise_amount="mid_snr", num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_orchestra=dataset_test_orchestra.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_orchestra,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("orchestra_midsnr")
|
||||
|
||||
dataset_test_opera=dataset_loader.load_data_formal( path_opera_test, path_noise, noise_amount="mid_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_opera=dataset_test_opera.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_opera,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("opera_midsnr")
|
||||
|
||||
dataset_test_strings=dataset_loader.load_data_formal( path_strings_test, path_noise,noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_strings=dataset_test_strings.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_strings,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("strings_lowsnr")
|
||||
|
||||
dataset_test_orchestra=dataset_loader.load_data_formal( path_orchestra_test, path_noise,noise_amount="low_snr", num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_orchestra=dataset_test_orchestra.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_orchestra,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("orchestra_lowsnr")
|
||||
|
||||
dataset_test_opera=dataset_loader.load_data_formal( path_opera_test, path_noise, noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_opera=dataset_test_opera.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_opera,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("opera_lowsnr")
|
||||
|
||||
|
||||
dataset_test_pianos=dataset_loader.load_data_formal( path_pianos_test, path_noise, noise_amount="low_snr",num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
|
||||
dataset_test_pianos=dataset_test_pianos.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
tester.init_inference(dataset_test_pianos,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
|
||||
metrics=tester.inference("pianos_lowsnr")
|
||||
|
||||
|
||||
names=["strings_midsnr","strings_lowsnr","opera_midsnr","opera_lowsnr","pianos_midsnr","pianos_lowsnr","orchestra_midsnr","orchestra_lowsnr"]
|
||||
for n in names:
|
||||
a=pd.read_csv(os.path.join(testPath,n,"metrics.csv"))
|
||||
meanPEAQ=a["PEAQ(ODG)_diff"].sum()/50
|
||||
meanPEMOQ=a["PEMOQ(ODG)_diff"].sum()/50
|
||||
meanSDR=a["SDR_diff"].sum()/50
|
||||
print(n,": PEAQ ",str(meanPEAQ), "PEMOQ ", str(meanPEMOQ), "SDR ", str(meanSDR))
|
||||
|
||||
def _main(args):
|
||||
global __file__
|
||||
|
||||
__file__ = hydra.utils.to_absolute_path(__file__)
|
||||
|
||||
run(args)
|
||||
|
||||
|
||||
@hydra.main(config_path="conf/conf.yaml")
|
||||
def main(args):
|
||||
try:
|
||||
_main(args)
|
||||
except Exception:
|
||||
logger.exception("Some error happened")
|
||||
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
|
||||
os._exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
391
tester.py
Normal file
391
tester.py
Normal file
@ -0,0 +1,391 @@
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import librosa
|
||||
import imageio
|
||||
import tensorflow as tf
|
||||
import soundfile as sf
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
from vggish.vgg_distance import process_wav
|
||||
import pandas as pd
|
||||
from scipy.io import loadmat
|
||||
|
||||
class Tester():
|
||||
def __init__(self, model, path_experiment, args):
|
||||
if model !=None:
|
||||
self.model=model
|
||||
print(self.model.summary())
|
||||
self.args=args
|
||||
self.path_experiment=path_experiment
|
||||
|
||||
def init_inference(self, dataset_test=None,num_test_segments=0 , fs=44100, stft_args=None, PEAQ_dir=None, alg_dir=None, PEMOQ_dir=None):
|
||||
|
||||
self.num_test_segments=num_test_segments
|
||||
self.dataset_test=dataset_test
|
||||
|
||||
if self.dataset_test!=None:
|
||||
self.dataset_test=self.dataset_test.take(self.num_test_segments)
|
||||
|
||||
self.fs=fs
|
||||
self.stft_args=stft_args
|
||||
self.win_size=stft_args.win_size
|
||||
self.hop_size=stft_args.hop_size
|
||||
self.window=stft_args.window
|
||||
self.PEAQ_dir=PEAQ_dir
|
||||
self.PEMOQ_dir=PEMOQ_dir
|
||||
self.alg_dir=alg_dir
|
||||
|
||||
|
||||
|
||||
|
||||
def generate_inverse_window(self, stft_args):
|
||||
if stft_args.window=="hamming":
|
||||
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.hamming_window)
|
||||
elif stft_args.window=="hann":
|
||||
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.hann_window)
|
||||
elif stft_args.window=="kaiser_bessel":
|
||||
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.kaiser_bessel_derived_window)
|
||||
def do_istft(self,data):
|
||||
|
||||
window_fn = self.generate_inverse_window(self.stft_args)
|
||||
win_size=self.win_size
|
||||
hop_size=self.hop_size
|
||||
pred_cpx=data[...,0] + 1j * data[...,1]
|
||||
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=window_fn)
|
||||
return pred_time
|
||||
|
||||
def generate_images(self,cpx,name):
|
||||
spectro=np.clip((np.flipud(np.transpose(10*np.log10(np.sqrt(np.power(cpx[...,0],2)+np.power(cpx[...,1],2)))))+30)/50,0,1)
|
||||
spectrorgb=np.zeros(shape=(spectro.shape[0],spectro.shape[1],3))
|
||||
spectrorgb[...,0]=np.clip((np.flipud(np.transpose(10*np.log10(np.abs(cpx[...,0])+0.001)))+30)/50,0,1)
|
||||
spectrorgb[...,1]=np.clip((np.flipud(np.transpose(10*np.log10(np.abs(cpx[...,1])+0.001)))+30)/50,0,1)
|
||||
cmap=cv2.COLORMAP_JET
|
||||
spectro = np.array((1-spectro)* 255, dtype = np.uint8)
|
||||
spectro = cv2.applyColorMap(spectro, cmap)
|
||||
imageio.imwrite(os.path.join(self.test_results_filepath, name+".png"),spectro)
|
||||
spectrorgb = np.array(spectrorgb* 255, dtype = np.uint8)
|
||||
imageio.imwrite(os.path.join(self.test_results_filepath, name+"_ir.png"),spectrorgb)
|
||||
|
||||
def generate_image_diff(self,clean , pred,name):
|
||||
difference=np.sqrt((clean[...,0]-pred[...,0])**2+(clean[...,1]-pred[...,1])**2)
|
||||
dif=np.clip(np.flipud(np.transpose(difference)),0,1)
|
||||
cmap=cv2.COLORMAP_JET
|
||||
dif = np.array((1-dif)* 255, dtype = np.uint8)
|
||||
dif = cv2.applyColorMap(dif, cmap)
|
||||
imageio.imwrite(os.path.join(self.test_results_filepath, name+"_diff.png"),dif)
|
||||
|
||||
def inference_inner_classical(self, folder_name, method):
|
||||
nums=[]
|
||||
|
||||
PEAQ_odg_noisy=[]
|
||||
PEAQ_odg_output=[]
|
||||
PEAQ_odg_diff=[]
|
||||
|
||||
PEMOQ_odg_noisy=[]
|
||||
PEMOQ_odg_output=[]
|
||||
PEMOQ_odg_diff=[]
|
||||
|
||||
SDR_noisy=[]
|
||||
SDR_output=[]
|
||||
SDR_diff=[]
|
||||
|
||||
VGGish_noisy=[]
|
||||
VGGish_output=[]
|
||||
VGGish_diff=[]
|
||||
|
||||
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
|
||||
if not os.path.exists(self.test_results_filepath):
|
||||
os.makedirs(self.test_results_filepath)
|
||||
num=0
|
||||
for element in tqdm(self.dataset_test.take(self.num_test_segments)):
|
||||
test_element=tf.data.Dataset.from_tensors(element)
|
||||
noisy_time=element[0].numpy()
|
||||
#noisy_time=self.do_istft(noisy)
|
||||
name_noisy=str(num)+'_noisy'
|
||||
clean_time=element[1].numpy()
|
||||
#clean_time=self.do_istft(clean)
|
||||
name_clean=str(num)+'_clean'
|
||||
print("inferencing")
|
||||
|
||||
|
||||
nums.append(num)
|
||||
|
||||
print("generating wavs")
|
||||
#noisy_time=noisy_time.numpy().astype(np.float32)
|
||||
noisy_time=noisy_time.astype(np.float32)
|
||||
wav_noisy_name_pre=os.path.join(self.test_results_filepath, name_noisy+"pre.wav")
|
||||
sf.write(wav_noisy_name_pre, noisy_time, 44100)
|
||||
|
||||
#pred = self.model.predict(test_element.batch(1))
|
||||
name_pred=str(num)+'_output'
|
||||
wav_output_name_proc=os.path.join(self.test_results_filepath, name_pred+"proc.wav")
|
||||
self.process_in_matlab(wav_noisy_name_pre, wav_output_name_proc, method)
|
||||
|
||||
noisy_time=noisy_time[44100::] #remove pre noise
|
||||
|
||||
#clean_time=clean_time.numpy().astype(np.float32)
|
||||
clean_time=clean_time.astype(np.float32)
|
||||
clean_time=clean_time[44100::] #remove pre noise
|
||||
|
||||
#change that !!!!
|
||||
#pred_time=self.do_istft(pred[0])
|
||||
#pred_time=pred_time.numpy().astype(np.float32)
|
||||
#pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000)
|
||||
#sf.write(wav_output_name, pred_time, 48000)
|
||||
#LOAD THE AUDIO!!!
|
||||
pred_time, sr=sf.read(wav_output_name_proc)
|
||||
assert sr==44100
|
||||
pred_time=pred_time[44100::] #remove prenoise
|
||||
|
||||
#I am computing here the SDR at 48k, whle I was doing it before at 44.1k. I hope this won't cause any problem in the results. Consider resampling???
|
||||
SDR_t_noisy=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(noisy_time-clean_time)))
|
||||
SDR_noisy.append(SDR_t_noisy)
|
||||
SDR_t_output=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(pred_time-clean_time)))
|
||||
SDR_output.append(SDR_t_output)
|
||||
SDR_diff.append(SDR_t_output-SDR_t_noisy)
|
||||
|
||||
noisy_time=librosa.resample(np.transpose(noisy_time),self.fs, 48000) #P.Kabal PEAQ code is hardcoded at Fs=48000, so we have to resample
|
||||
wav_noisy_name=os.path.join(self.test_results_filepath, name_noisy+".wav")
|
||||
sf.write(wav_noisy_name, noisy_time, 48000) #overwrite without prenoise
|
||||
|
||||
clean_time=librosa.resample(np.transpose(clean_time),self.fs, 48000) #without prenoise please!!!
|
||||
wav_clean_name=os.path.join(self.test_results_filepath, name_clean+".wav")
|
||||
sf.write(wav_clean_name, clean_time, 48000)
|
||||
|
||||
pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000) #without prenoise please!!!
|
||||
wav_output_name=os.path.join(self.test_results_filepath, name_pred+".wav")
|
||||
sf.write(wav_output_name, pred_time, 48000)
|
||||
|
||||
#save pred at 48k
|
||||
#print("calculating PEMOQ")
|
||||
#odg_noisy,odg_output =self.calculate_PEMOQ(wav_clean_name,wav_noisy_name,wav_output_name)
|
||||
#PEMOQ_odg_noisy.append(odg_noisy)
|
||||
#PEMOQ_odg_output.append(odg_output)
|
||||
#PEMOQ_odg_diff.append(odg_output-odg_noisy)
|
||||
|
||||
#print("calculating PEAQ")
|
||||
#odg_noisy,odg_output =self.calculate_PEAQ(wav_clean_name,wav_noisy_name,wav_output_name)
|
||||
#PEAQ_odg_noisy.append(odg_noisy)
|
||||
#PEAQ_odg_output.append(odg_output)
|
||||
#PEAQ_odg_diff.append(odg_output-odg_noisy)
|
||||
|
||||
print("calculating VGGish")
|
||||
VGGish_clean_embeddings=process_wav(wav_clean_name)
|
||||
VGGish_noisy_embeddings=process_wav(wav_noisy_name)
|
||||
VGGish_output_embeddings=process_wav(wav_output_name)
|
||||
dist_noisy = np.linalg.norm(VGGish_noisy_embeddings-VGGish_clean_embeddings)
|
||||
dist_output = np.linalg.norm(VGGish_output_embeddings-VGGish_clean_embeddings)
|
||||
VGGish_noisy.append(dist_noisy)
|
||||
VGGish_output.append(dist_output)
|
||||
VGGish_diff.append(-(dist_output-dist_noisy))
|
||||
os.remove(wav_clean_name)
|
||||
os.remove(wav_noisy_name)
|
||||
os.remove(wav_noisy_name_pre)
|
||||
os.remove(wav_output_name)
|
||||
os.remove(wav_output_name_proc)
|
||||
|
||||
num=num+1
|
||||
|
||||
frame = { 'num':nums,'PEAQ(ODG)_noisy': PEAQ_odg_noisy, 'PEAQ(ODG)_output': PEAQ_odg_output, 'PEAQ(ODG)_diff': PEAQ_odg_diff, 'PEMOQ(ODG)_noisy': PEMOQ_odg_noisy, 'PEMOQ(ODG)_output': PEMOQ_odg_output, 'PEMOQ(ODG)_diff': PEMOQ_odg_diff,'SDR_noisy': SDR_noisy, 'SDR_output': SDR_output, 'SDR_diff': SDR_diff, 'VGGish_noisy': VGGish_noisy, 'VGGish_output': VGGish_output,'VGGish_diff': VGGish_diff }
|
||||
|
||||
metrics=pd.DataFrame(frame)
|
||||
metrics.to_csv(os.path.join(self.test_results_filepath,"metrics.csv"),index=False)
|
||||
metrics=metrics.set_index('num')
|
||||
|
||||
return metrics
|
||||
def inference_inner(self, folder_name):
|
||||
nums=[]
|
||||
|
||||
PEAQ_odg_noisy=[]
|
||||
PEAQ_odg_output=[]
|
||||
PEAQ_odg_diff=[]
|
||||
|
||||
PEMOQ_odg_noisy=[]
|
||||
PEMOQ_odg_output=[]
|
||||
PEMOQ_odg_diff=[]
|
||||
|
||||
SDR_noisy=[]
|
||||
SDR_output=[]
|
||||
SDR_diff=[]
|
||||
|
||||
VGGish_noisy=[]
|
||||
VGGish_output=[]
|
||||
VGGish_diff=[]
|
||||
|
||||
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
|
||||
if not os.path.exists(self.test_results_filepath):
|
||||
os.makedirs(self.test_results_filepath)
|
||||
num=0
|
||||
for element in tqdm(self.dataset_test.take(self.num_test_segments)):
|
||||
test_element=tf.data.Dataset.from_tensors(element)
|
||||
noisy=element[0].numpy()
|
||||
noisy_time=self.do_istft(noisy)
|
||||
name_noisy=str(num)+'_noisy'
|
||||
clean=element[1].numpy()
|
||||
clean_time=self.do_istft(clean)
|
||||
name_clean=str(num)+'_clean'
|
||||
print("inferencing")
|
||||
pred = self.model.predict(test_element.batch(1))
|
||||
if self.args.unet.num_stages==2:
|
||||
pred=pred[0]
|
||||
pred_time=self.do_istft(pred[0])
|
||||
name_pred=str(num)+'_output'
|
||||
|
||||
nums.append(num)
|
||||
pred_time=pred_time.numpy().astype(np.float32)
|
||||
clean_time=clean_time.numpy().astype(np.float32)
|
||||
SDR_t_noisy=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(noisy_time-clean_time)))
|
||||
SDR_t_output=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(pred_time-clean_time)))
|
||||
SDR_noisy.append(SDR_t_noisy)
|
||||
SDR_output.append(SDR_t_output)
|
||||
SDR_diff.append(SDR_t_output-SDR_t_noisy)
|
||||
|
||||
print("generating wavs")
|
||||
noisy_time=librosa.resample(np.transpose(noisy_time),self.fs, 48000) #P.Kabal PEAQ code is hardcoded at Fs=48000, so we have to resample
|
||||
clean_time=librosa.resample(np.transpose(clean_time),self.fs, 48000)
|
||||
pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000)
|
||||
|
||||
wav_noisy_name=os.path.join(self.test_results_filepath, name_noisy+".wav")
|
||||
sf.write(wav_noisy_name, noisy_time, 48000)
|
||||
wav_clean_name=os.path.join(self.test_results_filepath, name_clean+".wav")
|
||||
sf.write(wav_clean_name, clean_time, 48000)
|
||||
wav_output_name=os.path.join(self.test_results_filepath, name_pred+".wav")
|
||||
sf.write(wav_output_name, pred_time, 48000)
|
||||
|
||||
print("calculating PEMOQ")
|
||||
odg_noisy,odg_output =self.calculate_PEMOQ(wav_clean_name,wav_noisy_name,wav_output_name)
|
||||
PEMOQ_odg_noisy.append(odg_noisy)
|
||||
PEMOQ_odg_output.append(odg_output)
|
||||
PEMOQ_odg_diff.append(odg_output-odg_noisy)
|
||||
print("calculating PEAQ")
|
||||
odg_noisy,odg_output =self.calculate_PEAQ(wav_clean_name,wav_noisy_name,wav_output_name)
|
||||
PEAQ_odg_noisy.append(odg_noisy)
|
||||
PEAQ_odg_output.append(odg_output)
|
||||
PEAQ_odg_diff.append(odg_output-odg_noisy)
|
||||
|
||||
print("calculating VGGish")
|
||||
VGGish_clean_embeddings=process_wav(wav_clean_name)
|
||||
VGGish_noisy_embeddings=process_wav(wav_noisy_name)
|
||||
VGGish_output_embeddings=process_wav(wav_output_name)
|
||||
dist_noisy = np.linalg.norm(VGGish_noisy_embeddings-VGGish_clean_embeddings)
|
||||
dist_output = np.linalg.norm(VGGish_output_embeddings-VGGish_clean_embeddings)
|
||||
VGGish_noisy.append(dist_noisy)
|
||||
VGGish_output.append(dist_output)
|
||||
VGGish_diff.append(-(dist_output-dist_noisy))
|
||||
os.remove(wav_clean_name)
|
||||
os.remove(wav_noisy_name)
|
||||
os.remove(wav_output_name)
|
||||
|
||||
num=num+1
|
||||
|
||||
frame = { 'num':nums,'PEAQ(ODG)_noisy': PEAQ_odg_noisy, 'PEAQ(ODG)_output': PEAQ_odg_output, 'PEAQ(ODG)_diff': PEAQ_odg_diff, 'PEMOQ(ODG)_noisy': PEMOQ_odg_noisy, 'PEMOQ(ODG)_output': PEMOQ_odg_output, 'PEMOQ(ODG)_diff': PEMOQ_odg_diff,'SDR_noisy': SDR_noisy, 'SDR_output': SDR_output, 'SDR_diff': SDR_diff, 'VGGish_noisy': VGGish_noisy, 'VGGish_output': VGGish_output,'VGGish_diff': VGGish_diff }
|
||||
|
||||
metrics=pd.DataFrame(frame)
|
||||
metrics.to_csv(os.path.join(self.test_results_filepath,"metrics.csv"),index=False)
|
||||
metrics=metrics.set_index('num')
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def inference_real(self, folder_name):
|
||||
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
|
||||
if not os.path.exists(self.test_results_filepath):
|
||||
os.makedirs(self.test_results_filepath)
|
||||
num=0
|
||||
for element in tqdm(self.dataset_real.take(self.num_real_test_segments)):
|
||||
test_element=tf.data.Dataset.from_tensors(element)
|
||||
noisy=element.numpy()
|
||||
noisy_time=self.do_istft(noisy)
|
||||
name_noisy="recording_"+str(num)+'_noisy.wav'
|
||||
pred = self.model.predict(test_element.batch(1))
|
||||
if self.args.unet.num_stages==2:
|
||||
pred=pred[0]
|
||||
pred_time=self.do_istft(pred[0])
|
||||
name_pred="recording_"+str(num)+'_output.wav'
|
||||
sf.write(os.path.join(self.test_results_filepath, name_noisy), noisy_time, self.fs)
|
||||
sf.write(os.path.join(self.test_results_filepath, name_pred), pred_time, self.fs)
|
||||
self.generate_images(noisy,name_noisy)
|
||||
self.generate_images(pred[0],name_pred)
|
||||
num=num+1
|
||||
|
||||
|
||||
def process_in_matlab(self,wav_noisy_name,wav_output_name,mode): #Opening and closing matlab to calculate PEAQ, rudimentary way to do it but easier. Make sure to have matlab installed
|
||||
addpath=self.alg_dir
|
||||
#odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_noisy.mat")
|
||||
#odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pred.mat")
|
||||
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
|
||||
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), declick_and_denoise(\""+wav_noisy_name+"\",\""+wav_output_name+"\",\""+mode+"\") , exit'"
|
||||
print(bashCommand)
|
||||
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
|
||||
(output, err) = p1.communicate()
|
||||
|
||||
print(output)
|
||||
|
||||
p1.wait()
|
||||
|
||||
def calculate_PEMOQ(self,wav_clean_name,wav_noisy_name,wav_output_name): #Opening and closing matlab to calculate PEAQ, rudimentary way to do it but easier. Make sure to have matlab installed
|
||||
addpath=self.PEMOQ_dir
|
||||
odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_pemo_noisy.mat")
|
||||
odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pemo_pred.mat")
|
||||
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
|
||||
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [ ODG]=PEMOQ(\""+wav_clean_name+"\",\""+wav_noisy_name+"\"), save(\""+odgmatfile_noisy+"\",\"ODG\"), exit'"
|
||||
print(bashCommand)
|
||||
|
||||
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
|
||||
(output, err) = p1.communicate()
|
||||
|
||||
print(output)
|
||||
|
||||
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [ ODG]=PEMOQ(\""+wav_clean_name+"\",\""+wav_output_name+"\"), save(\""+odgmatfile_pred+"\",\"ODG\"), exit'"
|
||||
|
||||
p2 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
|
||||
(output, err) = p2.communicate()
|
||||
|
||||
print(output)
|
||||
p1.wait()
|
||||
p2.wait()
|
||||
#I save the odg results in a .mat file, which I load here. Not the most optimal method, sorry :/
|
||||
annots_noise = loadmat(odgmatfile_noisy)
|
||||
annots_pred = loadmat(odgmatfile_pred)
|
||||
#Consider loading also the movs!!
|
||||
return annots_noise["ODG"][0][0], annots_pred["ODG"][0][0]
|
||||
|
||||
def calculate_PEAQ(self,wav_clean_name,wav_noisy_name,wav_output_name): #Opening and closing matlab to calculate PEAQ, rudimentary way to do it but easier. Make sure to have matlab installed
|
||||
addpath=self.PEAQ_dir
|
||||
odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_noisy.mat")
|
||||
odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pred.mat")
|
||||
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
|
||||
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [odg, MOV]=PQevalAudio(\""+wav_clean_name+"\",\""+wav_noisy_name+"\"), save(\""+odgmatfile_noisy+"\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
|
||||
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
|
||||
(output, err) = p1.communicate()
|
||||
|
||||
print(output)
|
||||
|
||||
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [odg, MOV]=PQevalAudio(\""+wav_clean_name+"\",\""+wav_output_name+"\"), save(\""+odgmatfile_pred+"\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
|
||||
p2 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
|
||||
(output, err) = p2.communicate()
|
||||
|
||||
print(output)
|
||||
p1.wait()
|
||||
p2.wait()
|
||||
#I save the odg results in a .mat file, which I load here. Not the most optimal method, sorry :/
|
||||
annots_noise = loadmat(odgmatfile_noisy)
|
||||
annots_pred = loadmat(odgmatfile_pred)
|
||||
#Consider loading also the movs!!
|
||||
return annots_noise["odg"][0][0], annots_pred["odg"][0][0]
|
||||
|
||||
def inference(self, name, method=None):
|
||||
print("Inferencing :",name)
|
||||
if self.dataset_test!=None:
|
||||
if method=="EM":
|
||||
return self.inference_inner_classical(name, "EM")
|
||||
elif method=="wiener":
|
||||
return self.inference_inner_classical(name, "wiener")
|
||||
elif method=="wiener_declick":
|
||||
return self.inference_inner_classical(name, "wiener_declick")
|
||||
elif method=="EM_declick":
|
||||
return self.inference_inner_classical(name, "EM_declick")
|
||||
else:
|
||||
return self.inference_inner(name)
|
||||
171
train.py
Normal file
171
train.py
Normal file
@ -0,0 +1,171 @@
|
||||
import os
|
||||
import hydra
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import tensorflow as tf
|
||||
import tensorflow_addons as tfa
|
||||
import dataset_loader
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
import soundfile as sf
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
path_experiment=str(args.path_experiment)
|
||||
|
||||
if not os.path.exists(path_experiment):
|
||||
os.makedirs(path_experiment)
|
||||
|
||||
path_music_train=args.dset.path_music_train
|
||||
path_music_test=args.dset.path_music_test
|
||||
path_music_validation=args.dset.path_music_validation
|
||||
|
||||
path_noise=args.dset.path_noise
|
||||
path_recordings=args.dset.path_recordings
|
||||
|
||||
fs=args.fs
|
||||
overlap=args.overlap
|
||||
seg_len_s_train=args.seg_len_s_train
|
||||
|
||||
batch_size=args.batch_size
|
||||
epochs=args.epochs
|
||||
|
||||
num_real_test_segments=args.num_real_test_segments
|
||||
buffer_size=args.buffer_size #for shuffle
|
||||
|
||||
tensorboard_logs=args.tensorboard_logs
|
||||
|
||||
def do_stft(noisy, clean=None):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size=args.stft.win_size
|
||||
hop_size=args.stft.hop_size
|
||||
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
|
||||
if clean!=None:
|
||||
|
||||
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
|
||||
|
||||
|
||||
return stft_noisy_stacked, stft_clean_stacked
|
||||
else:
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
#Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
||||
|
||||
dataset_train, dataset_val=dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)
|
||||
|
||||
dataset_train=dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
dataset_val=dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
|
||||
|
||||
with strategy.scope():
|
||||
#build the model
|
||||
unet_model = unet.build_model_denoise(unet_args=args.unet)
|
||||
|
||||
current_lr=args.lr
|
||||
optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)
|
||||
|
||||
loss=tf.keras.losses.MeanAbsoluteError()
|
||||
|
||||
if args.use_tensorboard:
|
||||
log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||
train_summary_writer = tf.summary.create_file_writer(log_dir+"/train")
|
||||
val_summary_writer = tf.summary.create_file_writer(log_dir+"/validation")
|
||||
|
||||
#path where the checkpoints will be saved
|
||||
checkpoint_filepath=os.path.join(path_experiment, 'checkpoint')
|
||||
|
||||
dataset_train=dataset_train.batch(batch_size)
|
||||
dataset_val=dataset_val.batch(batch_size)
|
||||
|
||||
#prefetching the dataset for better performance
|
||||
dataset_train=dataset_train.prefetch(batch_size*20)
|
||||
dataset_val=dataset_val.prefetch(batch_size*20)
|
||||
|
||||
dataset_train=strategy.experimental_distribute_dataset(dataset_train)
|
||||
dataset_val=strategy.experimental_distribute_dataset(dataset_val)
|
||||
|
||||
iterator = iter(dataset_train)
|
||||
|
||||
from trainer import Trainer
|
||||
|
||||
trainer=Trainer(unet_model,optimizer,loss,strategy, path_experiment, args)
|
||||
|
||||
for epoch in range(epochs):
|
||||
total_loss=0
|
||||
step_loss=0
|
||||
for step in tqdm(range(args.steps_per_epoch), desc="Training epoch "+str(epoch)):
|
||||
step_loss=trainer.distributed_training_step(iterator.get_next())
|
||||
total_loss+=step_loss
|
||||
with train_summary_writer.as_default():
|
||||
tf.summary.scalar('batch_loss', step_loss, step=step)
|
||||
tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)
|
||||
|
||||
train_loss=total_loss/args.steps_per_epoch
|
||||
|
||||
for x in tqdm(dataset_val, desc="Validating epoch "+str(epoch)):
|
||||
trainer.distributed_test_step(x)
|
||||
|
||||
template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
|
||||
print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
|
||||
|
||||
with train_summary_writer.as_default():
|
||||
tf.summary.scalar('epoch_loss', train_loss, step=epoch)
|
||||
tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
|
||||
with val_summary_writer.as_default():
|
||||
tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
|
||||
tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)
|
||||
|
||||
trainer.train_mae.reset_states()
|
||||
trainer.val_loss.reset_states()
|
||||
trainer.val_mae.reset_states()
|
||||
|
||||
if (epoch+1) % 50 == 0:
|
||||
if args.variable_lr:
|
||||
current_lr*=1e-1
|
||||
trainer.optimizer.lr=current_lr
|
||||
try:
|
||||
unet_model.save_weights(checkpoint_filpath)
|
||||
except:
|
||||
pass
|
||||
|
||||
def _main(args):
|
||||
global __file__
|
||||
|
||||
__file__ = hydra.utils.to_absolute_path(__file__)
|
||||
|
||||
run(args)
|
||||
|
||||
|
||||
@hydra.main(config_path="conf/conf.yaml")
|
||||
def main(args):
|
||||
try:
|
||||
_main(args)
|
||||
except Exception:
|
||||
logger.exception("Some error happened")
|
||||
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
|
||||
os._exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
71
trainer.py
Normal file
71
trainer.py
Normal file
@ -0,0 +1,71 @@
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import librosa
|
||||
import imageio
|
||||
import tensorflow as tf
|
||||
import soundfile as sf
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from scipy.io import loadmat
|
||||
|
||||
class Trainer():
|
||||
def __init__(self, model, optimizer,loss, strategy, path_experiment, args):
|
||||
self.model=model
|
||||
print(self.model.summary())
|
||||
self.strategy=strategy
|
||||
self.optimizer=optimizer
|
||||
self.path_experiment=path_experiment
|
||||
self.args=args
|
||||
#self.metrics=[]
|
||||
|
||||
with self.strategy.scope():
|
||||
#loss_fn=tf.keras.losses.mean_absolute_error
|
||||
loss.reduction=tf.keras.losses.Reduction.NONE
|
||||
self.loss_object=loss
|
||||
self.train_mae_s1=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
|
||||
self.train_mae=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
|
||||
self.val_mae=tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
|
||||
self.val_loss = tf.keras.metrics.Mean(name='test_loss')
|
||||
|
||||
|
||||
def train_step(self,inputs):
|
||||
noisy, clean= inputs
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
|
||||
logits_2,logits_1 = self.model(noisy, training=True) # Logits for this minibatch
|
||||
|
||||
loss_value = tf.reduce_mean(self.loss_object(clean, logits_2) + tf.reduce_mean(self.loss_object(clean, logits_1)))
|
||||
|
||||
grads = tape.gradient(loss_value, self.model.trainable_weights)
|
||||
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
|
||||
self.train_mae.update_state(clean, logits_2)
|
||||
self.train_mae_s1.update_state(clean, logits_1)
|
||||
return loss_value
|
||||
|
||||
def test_step(self,inputs):
|
||||
|
||||
noisy,clean = inputs
|
||||
|
||||
predictions_s2, predictions_s1 = self.model(noisy, training=False)
|
||||
t_loss = self.loss_object(clean, predictions_s2)+self.loss_object(clean, predictions_s1)
|
||||
|
||||
self.val_mae.update_state(clean,predictions_s2)
|
||||
self.val_loss.update_state(t_loss)
|
||||
|
||||
@tf.function()
|
||||
def distributed_training_step(self,inputs):
|
||||
per_replica_losses=self.strategy.run(self.train_step, args=(inputs,))
|
||||
reduced_losses=self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
|
||||
return reduced_losses
|
||||
|
||||
@tf.function
|
||||
def distributed_test_step(self,inputs):
|
||||
return self.strategy.run(self.test_step, args=(inputs,))
|
||||
|
||||
|
||||
|
||||
|
||||
486
unet.py
Normal file
486
unet.py
Normal file
@ -0,0 +1,486 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import Model, Input
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.initializers import TruncatedNormal
|
||||
import math as m
|
||||
|
||||
def build_model_denoise(unet_args=None):
|
||||
|
||||
inputs=Input(shape=(None, None,2))
|
||||
|
||||
outputs_stage_2,outputs_stage_1=MultiStage_denoise(unet_args=unet_args)(inputs)
|
||||
|
||||
#Encapsulating MultiStage_denoise in a keras.Model object
|
||||
model= tf.keras.Model(inputs=inputs,outputs=[outputs_stage_2, outputs_stage_1])
|
||||
|
||||
return model
|
||||
class DenseBlock(layers.Layer):
|
||||
'''
|
||||
[B, T, F, N] => [B, T, F, N]
|
||||
DenseNet Block consisting of "num_layers" densely connected convolutional layers
|
||||
'''
|
||||
def __init__(self, num_layers, N, ksize,activation):
|
||||
'''
|
||||
num_layers: number of densely connected conv. layers
|
||||
N: Number of filters (same in each layer)
|
||||
ksize: Kernel size (same in each layer)
|
||||
'''
|
||||
super(DenseBlock, self).__init__()
|
||||
self.activation=activation
|
||||
|
||||
self.paddings_1=get_paddings(ksize)
|
||||
self.H=[]
|
||||
self.num_layers=num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
self.H.append(layers.Conv2D(filters=N,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation))
|
||||
|
||||
def call(self, x):
|
||||
|
||||
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x_ = self.H[0](x_)
|
||||
if self.num_layers>1:
|
||||
for h in self.H[1:]:
|
||||
x = tf.concat([x_, x], axis=-1)
|
||||
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x_ = h(x_)
|
||||
|
||||
return x_
|
||||
|
||||
|
||||
class FinalBlock(layers.Layer):
|
||||
'''
|
||||
[B, T, F, N] => [B, T, F, 2]
|
||||
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
super(FinalBlock, self).__init__()
|
||||
ksize=(3,3)
|
||||
self.paddings_2=get_paddings(ksize)
|
||||
self.conv2=layers.Conv2D(filters=2,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
|
||||
|
||||
def call(self, inputs ):
|
||||
|
||||
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||
pred=self.conv2(x)
|
||||
|
||||
return pred
|
||||
class SAM(layers.Layer):
|
||||
'''
|
||||
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
|
||||
Supervised Attention Module:
|
||||
The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
|
||||
The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer.
|
||||
The first stage output is then calculated adding the original input spectrogram to the residual noise.
|
||||
The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.
|
||||
|
||||
'''
|
||||
def __init__(self, n_feat):
|
||||
super(SAM, self).__init__()
|
||||
|
||||
ksize=(3,3)
|
||||
self.paddings_1=get_paddings(ksize)
|
||||
self.conv1 = layers.Conv2D(filters=n_feat,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
ksize=(3,3)
|
||||
self.paddings_2=get_paddings(ksize)
|
||||
self.conv2=layers.Conv2D(filters=2,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
|
||||
ksize=(3,3)
|
||||
self.paddings_3=get_paddings(ksize)
|
||||
self.conv3 = layers.Conv2D(filters=n_feat,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
self.cropadd=CropAddBlock()
|
||||
|
||||
def call(self, inputs, input_spectrogram):
|
||||
x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
|
||||
x1 = self.conv1(x1)
|
||||
|
||||
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||
x=self.conv2(x)
|
||||
|
||||
#residual prediction
|
||||
pred = layers.Add()([x, input_spectrogram]) #features to next stage
|
||||
|
||||
x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
|
||||
M=self.conv3(x3)
|
||||
|
||||
M= tf.keras.activations.sigmoid(M)
|
||||
x1=layers.Multiply()([x1, M])
|
||||
x1 = layers.Add()([x1, inputs]) #features to next stage
|
||||
|
||||
return x1, pred
|
||||
|
||||
|
||||
class AddFreqEncoding(layers.Layer):
|
||||
'''
|
||||
[B, T, F, 2] => [B, T, F, 12]
|
||||
Generates frequency positional embeddings and concatenates them as 10 extra channels
|
||||
This function is optimized for F=1025
|
||||
'''
|
||||
def __init__(self, f_dim):
|
||||
super(AddFreqEncoding, self).__init__()
|
||||
pi = tf.constant(m.pi)
|
||||
pi=tf.cast(pi,'float32')
|
||||
self.f_dim=f_dim #f_dim is fixed
|
||||
n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
|
||||
coss=tf.math.cos(pi*n)
|
||||
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
||||
self.fembeddings= f_channel
|
||||
|
||||
for k in range(1,10):
|
||||
coss=tf.math.cos(2**k*pi*n)
|
||||
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
||||
self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
|
||||
|
||||
|
||||
def call(self, input_tensor):
|
||||
|
||||
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
|
||||
time_dim = tf.shape(input_tensor)[1] # get time dimension
|
||||
|
||||
fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
|
||||
|
||||
|
||||
return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
|
||||
|
||||
|
||||
def get_paddings(K):
|
||||
return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]])
|
||||
|
||||
class Decoder(layers.Layer):
|
||||
'''
|
||||
[B, T, F, N] , skip connections => [B, T, F, N]
|
||||
Decoder side of the U-Net subnetwork.
|
||||
'''
|
||||
def __init__(self, Ns, Ss, unet_args):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
self.Ns=Ns
|
||||
self.Ss=Ss
|
||||
self.activation=unet_args.activation
|
||||
self.depth=unet_args.depth
|
||||
|
||||
|
||||
ksize=(3,3)
|
||||
self.paddings_3=get_paddings(ksize)
|
||||
self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
self.cropadd=CropAddBlock()
|
||||
|
||||
self.dblocks=[]
|
||||
for i in range(self.depth):
|
||||
self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc))
|
||||
|
||||
def call(self,inputs, contracting_layers):
|
||||
x=inputs
|
||||
for i in range(self.depth,0,-1):
|
||||
x=self.dblocks[i-1](x, contracting_layers[i-1])
|
||||
return x
|
||||
|
||||
class Encoder(tf.keras.Model):
|
||||
|
||||
'''
|
||||
[B, T, F, N] => skip connections , [B, T, F, N_4]
|
||||
Encoder side of the U-Net subnetwork.
|
||||
'''
|
||||
def __init__(self, Ns, Ss, unet_args):
|
||||
super(Encoder, self).__init__()
|
||||
self.Ns=Ns
|
||||
self.Ss=Ss
|
||||
self.activation=unet_args.activation
|
||||
self.depth=unet_args.depth
|
||||
|
||||
self.contracting_layers = {}
|
||||
|
||||
self.eblocks=[]
|
||||
for i in range(self.depth):
|
||||
self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
|
||||
|
||||
self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
|
||||
|
||||
def call(self, inputs):
|
||||
x=inputs
|
||||
for i in range(self.depth):
|
||||
|
||||
x, x_contract=self.eblocks[i](x)
|
||||
|
||||
self.contracting_layers[i] = x_contract #if remove 0, correct this
|
||||
x=self.i_block(x)
|
||||
|
||||
return x, self.contracting_layers
|
||||
|
||||
class MultiStage_denoise(tf.keras.Model):
|
||||
|
||||
def __init__(self, unet_args=None):
|
||||
super(MultiStage_denoise, self).__init__()
|
||||
|
||||
self.activation=unet_args.activation
|
||||
self.depth=unet_args.depth
|
||||
if unet_args.use_fencoding:
|
||||
self.freq_encoding=AddFreqEncoding(unet_args.f_dim)
|
||||
self.use_sam=unet_args.use_SAM
|
||||
self.use_fencoding=unet_args.use_fencoding
|
||||
self.num_stages=unet_args.num_stages
|
||||
#Encoder
|
||||
self.Ns= [32,64,64,128,128,256,512]
|
||||
self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)]
|
||||
|
||||
#initial feature extractor
|
||||
ksize=(7,7)
|
||||
self.paddings_1=get_paddings(ksize)
|
||||
self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
|
||||
self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
|
||||
self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
|
||||
|
||||
self.cropconcat = CropConcatBlock()
|
||||
self.cropadd = CropAddBlock()
|
||||
|
||||
self.finalblock=FinalBlock()
|
||||
|
||||
if self.num_stages>1:
|
||||
self.sam_1=SAM(self.Ns[0])
|
||||
|
||||
#initial feature extractor
|
||||
ksize=(7,7)
|
||||
self.paddings_2=get_paddings(ksize)
|
||||
self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
|
||||
self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
|
||||
self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
|
||||
|
||||
@tf.function()
|
||||
def call(self, inputs):
|
||||
|
||||
if self.use_fencoding:
|
||||
x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
|
||||
else:
|
||||
x_w_freq=inputs
|
||||
|
||||
#intitial feature extractor
|
||||
x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
|
||||
x=self.conv2d_1(x) #None, None, 1025, 32
|
||||
|
||||
x, contracting_layers_s1= self.encoder_s1(x)
|
||||
#decoder
|
||||
feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
|
||||
|
||||
if self.num_stages>1:
|
||||
#SAM module
|
||||
Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
|
||||
|
||||
#intitial feature extractor
|
||||
x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
|
||||
x=self.conv2d_2(x)
|
||||
|
||||
if self.use_sam:
|
||||
x = tf.concat([x, Fout], axis=-1)
|
||||
else:
|
||||
x = tf.concat([x,feats_s1], axis=-1)
|
||||
|
||||
x, contracting_layers_s2= self.encoder_s2(x)
|
||||
|
||||
feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
|
||||
|
||||
#consider implementing a third stage?
|
||||
|
||||
pred_stage_2=self.finalblock(feats_s2)
|
||||
return pred_stage_2, pred_stage_1
|
||||
else:
|
||||
pred_stage_1=self.finalblock(feats_s1)
|
||||
return pred_stage_1
|
||||
|
||||
class I_Block(layers.Layer):
|
||||
'''
|
||||
[B, T, F, N] => [B, T, F, N]
|
||||
Intermediate block:
|
||||
Basically, a densenet block with a residual connection
|
||||
'''
|
||||
def __init__(self,N,activation, num_tfc, **kwargs):
|
||||
super(I_Block, self).__init__(**kwargs)
|
||||
|
||||
ksize=(3,3)
|
||||
self.tfc=DenseBlock(num_tfc,N,ksize, activation)
|
||||
|
||||
self.conv2d_res= layers.Conv2D(filters=N,
|
||||
kernel_size=(1,1),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID')
|
||||
|
||||
def call(self,inputs):
|
||||
x=self.tfc(inputs)
|
||||
|
||||
inputs_proj=self.conv2d_res(inputs)
|
||||
return layers.Add()([x,inputs_proj])
|
||||
|
||||
|
||||
class E_Block(layers.Layer):
|
||||
|
||||
def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs):
|
||||
super(E_Block, self).__init__(**kwargs)
|
||||
self.layer_idx=layer_idx
|
||||
self.N0=N0
|
||||
self.N=N
|
||||
self.S=S
|
||||
self.activation=activation
|
||||
self.i_block=I_Block(N0,activation,num_tfc)
|
||||
|
||||
ksize=(S[0]+2,S[1]+2)
|
||||
self.paddings_2=get_paddings(ksize)
|
||||
self.conv2d_2 = layers.Conv2D(filters=N,
|
||||
kernel_size=(S[0]+2,S[1]+2),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=S,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
|
||||
def call(self, inputs, training=None, **kwargs):
|
||||
x=self.i_block(inputs)
|
||||
|
||||
x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
|
||||
x_down = self.conv2d_2(x_down)
|
||||
|
||||
return x_down, x
|
||||
|
||||
|
||||
def get_config(self):
|
||||
return dict(layer_idx=self.layer_idx,
|
||||
N=self.N,
|
||||
S=self.S,
|
||||
**super(E_Block, self).get_config()
|
||||
)
|
||||
class D_Block(layers.Layer):
|
||||
|
||||
def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs):
|
||||
super(D_Block, self).__init__(**kwargs)
|
||||
self.layer_idx=layer_idx
|
||||
self.N=N
|
||||
self.S=S
|
||||
self.activation=activation
|
||||
ksize=(S[0]+2, S[1]+2)
|
||||
self.paddings_1=get_paddings(ksize)
|
||||
|
||||
self.tconv_1= layers.Conv2DTranspose(filters=N,
|
||||
kernel_size=(S[0]+2, S[1]+2),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=S,
|
||||
activation=self.activation,
|
||||
padding='VALID')
|
||||
|
||||
self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
|
||||
|
||||
self.projection = layers.Conv2D(filters=N,
|
||||
kernel_size=(1,1),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
activation=self.activation,
|
||||
padding='VALID')
|
||||
self.cropadd=CropAddBlock()
|
||||
self.cropconcat=CropConcatBlock()
|
||||
|
||||
self.i_block=I_Block(N,activation,num_tfc)
|
||||
|
||||
def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs):
|
||||
x = inputs
|
||||
x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x = self.tconv_1(inputs)
|
||||
|
||||
x2= self.upsampling(inputs)
|
||||
|
||||
if x2.shape[-1]!=x.shape[-1]:
|
||||
x2= self.projection(x2)
|
||||
|
||||
x= self.cropadd(x,x2)
|
||||
|
||||
|
||||
x=self.cropconcat(x,bridge)
|
||||
|
||||
x=self.i_block(x)
|
||||
return x
|
||||
|
||||
def get_config(self):
|
||||
return dict(layer_idx=self.layer_idx,
|
||||
N=self.N,
|
||||
S=self.S,
|
||||
**super(D_Block, self).get_config()
|
||||
)
|
||||
|
||||
class CropAddBlock(layers.Layer):
|
||||
|
||||
def call(self,down_layer, x, **kwargs):
|
||||
x1_shape = tf.shape(down_layer)
|
||||
x2_shape = tf.shape(x)
|
||||
|
||||
|
||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||
|
||||
down_layer_cropped = down_layer[:,
|
||||
height_diff: (x2_shape[1] + height_diff),
|
||||
width_diff: (x2_shape[2] + width_diff),
|
||||
:]
|
||||
|
||||
x = layers.Add()([down_layer_cropped, x])
|
||||
return x
|
||||
|
||||
class CropConcatBlock(layers.Layer):
|
||||
|
||||
def call(self, down_layer, x, **kwargs):
|
||||
x1_shape = tf.shape(down_layer)
|
||||
x2_shape = tf.shape(x)
|
||||
|
||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||
|
||||
down_layer_cropped = down_layer[:,
|
||||
height_diff: (x2_shape[1] + height_diff),
|
||||
width_diff: (x2_shape[2] + width_diff),
|
||||
:]
|
||||
|
||||
x = tf.concat([down_layer_cropped, x], axis=-1)
|
||||
return x
|
||||
Loading…
Reference in New Issue
Block a user