starting again

Moliner Eloi 2021-08-30 18:30:51 +03:00
parent fd35cee560
commit c4cbdd2b8a
7 changed files with 1931 additions and 0 deletions

dataset_loader.py Normal file (497 lines)
@@ -0,0 +1,497 @@
from typing import Tuple, Dict
import ast
import tensorflow as tf
import random
import os
import numpy as np
from scipy.fft import fft, ifft
import soundfile as sf
import librosa
import math
import pandas as pd
import scipy as sp
import glob
from tqdm import tqdm
#Generator function. It reads the csv info file with pandas and loads noise segments from each recording. Segments longer than length_seq are trimmed and yielded with no further processing; shorter ones are extended to length_seq using concatenative synthesis (see __extend_sample_by_repeating).
def __noise_sample_generator(info_file,fs, length_seq, split):
head=os.path.split(info_file)[0]
load_data=pd.read_csv(info_file)
#split= train, validation, test
load_data_split=load_data.loc[load_data["split"]==split]
load_data_split=load_data_split.reset_index(drop=True)
while True:
r = list(range(len(load_data_split)))
if split!="test":
random.shuffle(r)
for i in r:
segments=ast.literal_eval(load_data_split.loc[i,"segments"])
if split=="test":
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
else:
num=np.random.randint(0,len(segments))
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
if fs!=Fs:
print("wrong fs, resampling...")
loaded_data=librosa.resample(loaded_data, Fs, fs) #assign back to loaded_data so the resampled audio is actually used
yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
def __extend_sample_by_repeating(data, fs,seq_len):
rpm=78
target_samp=seq_len
large_data=np.zeros(shape=(target_samp,2))
if len(data)>=target_samp:
large_data=data[0:target_samp]
return large_data
bls=(1000*44100)/1000 #1000 ms crossfade length at 44.1 kHz (hardcoded)
window=np.stack((np.hanning(bls) ,np.hanning(bls)), axis=1)
window_left=window[0:int(bls/2),:]
window_right=window[int(bls/2)::,:]
bls=int(bls/2)
rps=rpm/60
period=1/rps
period_sam=int(period*fs)
overhead=len(data)%period_sam
if(overhead>bls):
complete_periods=(len(data)//period_sam)*period_sam
else:
complete_periods=(len(data)//period_sam -1)*period_sam
a=np.multiply(data[0:bls], window_left)
b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
c_1=np.concatenate((data[0:complete_periods,:],b))
c_2=np.concatenate((a,data[bls:complete_periods,:],b))
c_3=np.concatenate((a,data[bls::,:]))
large_data[0:complete_periods+bls,:]=c_1
pointer=complete_periods
not_finished=True
while (not_finished):
if target_samp>pointer+complete_periods+bls:
large_data[pointer:pointer+complete_periods+bls] +=c_2
pointer+=complete_periods
else:
large_data[pointer::]+=c_3[0:(target_samp-pointer)]
#finish
not_finished=False
return large_data
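#__extend_sample_by_repeating loops the recording in whole 78 rpm revolutions
#(one period = 60/78 s ~= 0.77 s, i.e. period_sam samples) and crossfades the
#seams with half-Hann windows of ~0.5 s. A hedged usage sketch (the file name
#is a placeholder, not part of this repo):
#  noise, _ = sf.read("shellac_noise.wav")
#  looped = __extend_sample_by_repeating(noise, 44100, 44100*10) #10 s of noise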
def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
records_info=os.path.join(path_recordings,"audio_files.txt")
with open(records_info,"r") as f:
lines=f.readlines()
num_lines=len(lines)
#load data record files
print("Loading record files")
records=[]
seg_len=fs*seg_len_s
pointer=int(fs*5) #starting at second 5 by default
for i in tqdm(range(num_lines)):
audio=lines[i].rstrip("\n")
data, samplerate=sf.read(os.path.join(path_recordings,audio)) #do not overwrite the fs argument here
if len(data.shape)>1 and not(stereo):
data=np.mean(data,axis=1)
#elif stereo and len(data.shape)==1:
# data=np.stack((data, data), axis=1)
#normalize
data=data/np.max(np.abs(data))
segment=data[pointer:pointer+seg_len]
records.append(segment.astype("float32"))
return records
def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False):
print(num_samples)
segments_clean=[]
segments_noisy=[]
seg_len=fs*seg_len_s
noises_info=os.path.join(path_noises,"info.csv")
np.random.seed(42)
if noise_amount=="low_snr":
SNRs=np.random.uniform(2,6,num_samples)
elif noise_amount=="mid_snr":
SNRs=np.random.uniform(6,12,num_samples)
scales=np.random.uniform(-4,0,num_samples)
#SNRs=[2,6,12] #HARDCODED!!!!
i=0
print(path_pianos[0])
print(seg_len)
train_samples=glob.glob(os.path.join(path_pianos[0],"*.wav"))
train_samples=sorted(train_samples)
if prenoise:
noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, "test") #adds 1 s of silence at the beginning, so a longer noise is needed
else:
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
#load data clean files
for file in tqdm(train_samples): #add [1:5] for testing
data_clean, samplerate = sf.read(file)
if samplerate!=fs:
print("!!!!WRONG SAMPLE RATe!!!")
#Stereo to mono
if len(data_clean.shape)>1 and not(stereo):
data_clean=np.mean(data_clean,axis=1)
#elif stereo and len(data_clean.shape)==1:
# data_clean=np.stack((data_clean, data_clean), axis=1)
#normalize
data_clean=data_clean/np.max(np.abs(data_clean))
#data_clean_loaded.append(data_clean)
#framify data clean files
#framify arguments: seg_len, hop_size
hop_size=int(seg_len)# no overlap
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
print(num_frames)
if num_frames==0:
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
num_frames=1
data_not_finished=True
pointer=0
while(data_not_finished):
if i>=num_samples:
break
segment=data_clean[pointer:pointer+seg_len]
pointer=pointer+hop_size
if pointer+seg_len>len(data_clean):
data_not_finished=False
segment=segment.astype('float32')
#SNRs=np.random.uniform(2,20)
snr=SNRs[i]
scale=scales[i]
#load noise signal
data_noise= next(noise_generator)
data_noise=np.mean(data_noise,axis=1)
#normalize
data_noise=data_noise/np.max(np.abs(data_noise))
new_noise=data_noise #if more processing needed, add here
#load clean data
#configure sizes
power_clean=np.var(segment)
#estimate noise power
if prenoise:
power_noise=np.var(new_noise[fs::])
else:
power_noise=np.var(new_noise)
snr = 10.0**(snr/10.0)
#sum both signals according to snr
if prenoise:
segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
summed=summed.astype('float32')
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
summed=10.0**(scale/10.0) *summed
segment=10.0**(scale/10.0) *segment
segments_noisy.append(summed.astype('float32'))
segments_clean.append(segment.astype('float32'))
i=i+1
return segments_noisy, segments_clean
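#The mixing gain used above follows from the SNR definition: with
#snr_lin = 10**(snr_dB/10) and g = sqrt(P_clean/(snr_lin*P_noise)), the mixture
#clean + g*noise satisfies 10*log10(P_clean/(g**2*P_noise)) = snr_dB, so the
#"not sure if this is correct" formula is in fact consistent. A minimal,
#illustrative self-check (not called anywhere in this module):
def __check_snr_gain(snr_db=6.0):
    rng = np.random.default_rng(0)
    clean = rng.standard_normal(44100)
    noise = rng.standard_normal(44100)
    g = np.sqrt(np.var(clean)/(10.0**(snr_db/10.0)*np.var(noise)))
    return 10.0*np.log10(np.var(clean)/np.var(g*noise)) #~= snr_db for uncorrelated signals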
def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
segments_clean=[]
segments_noisy=[]
seg_len=fs*seg_len_s
noises_info=os.path.join(path_noises,"info.csv")
SNRs=[2,6,12] #HARDCODED!!!!
for path in path_music:
print(path)
train_samples=glob.glob(os.path.join(path,"*.wav"))
train_samples=sorted(train_samples)
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
#load data clean files
jj=0
for file in tqdm(train_samples): #add [1:5] for testing
data_clean, samplerate = sf.read(file)
if samplerate!=fs:
print("!!!!WRONG SAMPLE RATe!!!")
#Stereo to mono
if len(data_clean.shape)>1:
data_clean=np.mean(data_clean,axis=1)
#normalize
data_clean=data_clean/np.max(np.abs(data_clean))
#data_clean_loaded.append(data_clean)
#framify data clean files
#framify arguments: seg_len, hop_size
hop_size=int(seg_len)# no overlap
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
if num_frames==0:
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
num_frames=1
pointer=0
segment=data_clean[pointer:pointer+(seg_len-2*fs)]
segment=segment.astype('float32')
segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #prepend 2 s of silence so the total length stays seg_len
#segments_clean.append(segment)
for snr in SNRs:
#load noise signal
data_noise= next(noise_generator)
data_noise=np.mean(data_noise,axis=1)
#normalize
data_noise=data_noise/np.max(np.abs(data_noise))
new_noise=data_noise #if more processing needed, add here
#load clean data
#configure sizes
#estimate clean signal power
power_clean=np.var(segment)
#estimate noise power
power_noise=np.var(new_noise)
snr = 10.0**(snr/10.0)
#sum both signals according to snr
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
summed=summed.astype('float32')
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
segments_noisy.append(summed.astype('float32'))
segments_clean.append(segment.astype('float32'))
return segments_noisy, segments_clean
def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5):
val_samples=[]
for path in path_music:
val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
#load data clean files
print("Loading clean files")
data_clean_loaded=[]
for ff in tqdm(range(0,len(val_samples))): #add [1:5] for testing
data_clean, samplerate = sf.read(val_samples[ff])
if samplerate!=fs:
print("!!!!WRONG SAMPLE RATe!!!")
#Stereo to mono
if len(data_clean.shape)>1 :
data_clean=np.mean(data_clean,axis=1)
#normalize
data_clean=data_clean/np.max(np.abs(data_clean))
data_clean_loaded.append(data_clean)
del data_clean
#framify data clean files
print("Framifying clean files")
seg_len=fs*seg_len_s
segments_clean=[]
for file in tqdm(data_clean_loaded):
#framify arguments: seg_len, hop_size
hop_size=int(seg_len)# no overlap
num_frames=np.floor(len(file)/hop_size - seg_len/hop_size +1)
pointer=0
for i in range(0,int(num_frames)):
segment=file[pointer:pointer+seg_len]
pointer=pointer+hop_size
segment=segment.astype('float32')
segments_clean.append(segment)
del data_clean_loaded
SNRs=np.random.uniform(2,20,len(segments_clean))
scales=np.random.uniform(-6,4,len(segments_clean))
#noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
noises_info=os.path.join(path_noises,"info.csv")
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything
#generate noisy segments
#load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
#noise_samples=glob.glob(os.path.join(path_noises,"*.wav"))
segments_noisy=[]
print("Processing noisy segments")
for i in tqdm(range(0,len(segments_clean))):
#load noise signal
data_noise= next(noise_generator)
#Stereo to mono
data_noise=np.mean(data_noise,axis=1)
#normalize
data_noise=data_noise/np.max(np.abs(data_noise))
new_noise=data_noise #if more processing needed, add here
#load clean data
data_clean=segments_clean[i]
#configure sizes
#estimate clean signal power
power_clean=np.var(data_clean)
#estimate noise power
power_noise=np.var(new_noise)
snr = 10.0**(SNRs[i]/10.0)
#sum both signals according to snr
summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
#the rest is normal
summed=10.0**(scales[i]/10.0) *summed
segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i]
segments_noisy.append(summed.astype('float32'))
return segments_noisy, segments_clean
def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False):
train_samples=[]
for path in path_music:
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav")))
seg_len=fs*seg_len_s
noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything
#load data clean files
while True:
random.shuffle(train_samples)
for file in train_samples:
data, samplerate = sf.read(file)
if samplerate!=fs:
print("!!!!WRONG SAMPLE RATE!!!, resampling")
data=np.transpose(data)
data=librosa.resample(data, samplerate, fs) #resample to the requested fs only when needed, instead of always resampling to a hardcoded 44100
data=np.transpose(data)
data_clean=data
#Stereo to mono
if len(data.shape)>1 :
data_clean=np.mean(data_clean,axis=1)
#normalize
data_clean=data_clean/np.max(np.abs(data_clean))
#framify data clean files
#framify arguments: seg_len, hop_size
hop_size=int(seg_len)
num_frames=np.floor(len(data_clean)/seg_len)
if num_frames==0:
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
num_frames=1
pointer=0
data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation
elif num_frames>1:
pointer=np.random.randint(0,hop_size) #initial shifting, great for augmentation, better than overlap as we get different frames at each "while" iteration
else:
pointer=0
data_not_finished=True
while(data_not_finished):
segment=data_clean[pointer:pointer+seg_len]
pointer=pointer+hop_size
if pointer+seg_len>len(data_clean):
data_not_finished=False
segment=segment.astype('float32')
SNRs=np.random.uniform(2,20)
scale=np.random.uniform(-6,4)
#load noise signal
data_noise= next(noise_generator)
data_noise=np.mean(data_noise,axis=1)
#normalize
data_noise=data_noise/np.max(np.abs(data_noise))
new_noise=data_noise #if more processing needed, add here
#load clean data
#configure sizes
if stereo:
#estimate clean signal power
power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
#estimate noise power
power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
else:
#estimate clean signal power
power_clean=np.var(segment)
#estimate noise power
power_noise=np.var(new_noise)
snr = 10.0**(SNRs/10.0)
#sum both signals according to snr
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
summed=10.0**(scale/10.0) *summed
segment=10.0**(scale/10.0) *segment
summed=summed.astype('float32')
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
print("Generating train dataset")
trainshape=int(fs*seg_len_s)
dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) )
print("Generating validation dataset")
segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s)
dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
return dataset_train.shuffle(buffer_size), dataset_val
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs) -> Tuple[tf.data.Dataset]:
print("Generating test dataset")
segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs)
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
return dataset_test
def load_data_formal( path_pianos_test, path_noises, **kwargs) -> Tuple[tf.data.Dataset]:
print("Generating test dataset")
segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
print("segments::")
print(len(segments_noisy))
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
return dataset_test
def load_real_test_recordings(buffer_size, path_recordings, **kwargs) -> Tuple[tf.data.Dataset]:
print("Generating real test dataset")
segments_noisy=generate_real_recordings_data(path_recordings, **kwargs)
dataset_test=tf.data.Dataset.from_tensor_slices(segments_noisy)
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
return dataset_test
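#Rough usage sketch (paths and buffer size are placeholders, not values shipped
#with this repo):
#  train_ds, val_ds = load_data(8000, ["data/music_train"], ["data/music_val"], "data/noises")
#tf.data passes the generator arguments as byte strings, which is why
#generator_train calls .decode("utf-8") on the paths.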

inference.py Normal file (164 lines)
@@ -0,0 +1,164 @@
import os
import hydra
import logging
logger = logging.getLogger(__name__)
def run(args):
import unet
import tensorflow as tf
import soundfile as sf
import numpy as np
from tqdm import tqdm
import librosa
path_experiment=str(args.path_experiment)
unet_model = unet.build_model_denoise(unet_args=args.unet)
ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint')
unet_model.load_weights(ckpt)
def do_stft(noisy):
window_fn = tf.signal.hamming_window
win_size=args.stft.win_size
hop_size=args.stft.hop_size
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
return stft_noisy_stacked
def do_istft(data):
window_fn = tf.signal.hamming_window
win_size=args.stft.win_size
hop_size=args.stft.hop_size
inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
pred_cpx=data[...,0] + 1j * data[...,1]
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
return pred_time
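#Note: inverse_stft_window_fn derives the synthesis window that, together with
#the Hamming analysis window and this hop size, gives (least-squares) perfect
#reconstruction in the overlap-add of tf.signal.inverse_stft.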
audio=str(args.inference.audio)
data, samplerate = sf.read(audio)
if samplerate!=44100:
print("Resampling")
data=np.transpose(data)
data=librosa.resample(data, samplerate, 44100)
data=np.transpose(data)
#Stereo to mono
if len(data.shape)>1:
data=np.mean(data,axis=1)
segment_size=44100*20 #20 s segments
length_data=len(data)
overlapsize=2048 #samples (46 ms)
window=np.hanning(2*overlapsize)
window_right=window[overlapsize::]
window_left=window[0:overlapsize]
audio_finished=False
pointer=0
denoised_data=np.zeros(shape=(len(data),))
residual_noise=np.zeros(shape=(len(data),))
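#Chunked processing with crossfaded overlap-add: consecutive 20 s chunks
#overlap by `overlapsize` samples, and each seam is faded with complementary
#Hann halves (window_left and window_right sum to ~1 in the overlap), so the
#stitched output has no discontinuities at the chunk borders.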
numchunks=int(np.ceil(max(length_data-overlapsize,1)/(segment_size-overlapsize))) #count chunks by the effective hop; ceil(length/segment_size) could leave the tail unprocessed
for i in tqdm(range(numchunks)):
if pointer+segment_size<length_data:
segment=data[pointer:pointer+segment_size]
#dostft
segment_TF=do_stft(segment)
segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
pred = unet_model.predict(segment_TF_ds.batch(1))
pred=pred[0]
residual=segment_TF-pred[0]
residual=np.array(residual)
pred_time=do_istft(pred[0])
residual_time=do_istft(residual)
residual_time=np.array(residual_time)
if pointer==0:
pred_time=np.concatenate((pred_time[0:int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
residual_time=np.concatenate((residual_time[0:int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
else:
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
denoised_data[pointer:pointer+segment_size]=denoised_data[pointer:pointer+segment_size]+pred_time
residual_noise[pointer:pointer+segment_size]=residual_noise[pointer:pointer+segment_size]+residual_time
pointer=pointer+segment_size-overlapsize
else:
segment=data[pointer::]
lensegment=len(segment)
segment=np.concatenate((segment, np.zeros(shape=(int(segment_size-len(segment)),))), axis=0)
audio_finished=True
#dostft
segment_TF=do_stft(segment)
segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
pred = unet_model.predict(segment_TF_ds.batch(1))
pred=pred[0]
residual=segment_TF-pred[0]
residual=np.array(residual)
pred_time=do_istft(pred[0])
pred_time=np.array(pred_time)
pred_time=pred_time[0:segment_size]
residual_time=do_istft(residual)
residual_time=np.array(residual_time)
residual_time=residual_time[0:segment_size]
if pointer!=0:
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]),axis=0)
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]),axis=0)
denoised_data[pointer::]=denoised_data[pointer::]+pred_time[0:lensegment]
residual_noise[pointer::]=residual_noise[pointer::]+residual_time[0:lensegment]
basename=os.path.splitext(audio)[0]
wav_noisy_name=basename+"_noisy_input"+".wav"
sf.write(wav_noisy_name, data, 44100)
wav_output_name=basename+"_denoised"+".wav"
sf.write(wav_output_name, denoised_data, 44100)
wav_output_name=basename+"_residual"+".wav"
sf.write(wav_output_name, residual_noise, 44100)
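#Hedged example of invoking this script (the config keys follow conf/conf.yaml;
#the concrete values are placeholders, not repo defaults):
#  python inference.py path_experiment=experiments/exp1 inference.audio=audio/noisy_recording.wav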
def _main(args):
global __file__
__file__ = hydra.utils.to_absolute_path(__file__)
run(args)
@hydra.main(config_path="conf/conf.yaml")
def main(args):
try:
_main(args)
except Exception:
logger.exception("Some error happened")
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
os._exit(1)
if __name__ == "__main__":
main()

test.py Normal file (151 lines)
@@ -0,0 +1,151 @@
'''
Script used for the objective experiments.
WARNING: it calls MATLAB to calculate PEAQ and PEMO-Q, so the whole process may be very slow.
'''
import os
import hydra
import logging
logger = logging.getLogger(__name__)
def run(args):
import unet
import dataset_loader
import tensorflow as tf
import pandas as pd
path_experiment=str(args.path_experiment)
print(path_experiment)
if not os.path.exists(path_experiment):
os.makedirs(path_experiment)
unet_model = unet.build_model_denoise(unet_args=args.unet) #build_model_denoise takes only unet_args (see unet.py)
ckpt=os.path.join(path_experiment, 'checkpoint')
unet_model.load_weights(ckpt)
path_pianos_test=args.dset.path_piano_test
path_strings_test=args.dset.path_strings_test
path_orchestra_test=args.dset.path_orchestra_test
path_opera_test=args.dset.path_opera_test
path_noise=args.dset.path_noise
fs=args.fs
seg_len_s=20
numsamples=1000//seg_len_s #50 segments of 20 s each
stereo=False #assumed mono evaluation; the data loaders treat stereo=False as the default
def do_stft(noisy, clean=None):
if args.stft.window=="hamming":
window_fn = tf.signal.hamming_window
elif args.stft.window=="hann":
window_fn=tf.signal.hann_window
elif args.stft.window=="kaiser_bessel":
window_fn=tf.signal.kaiser_bessel_derived_window
win_size=args.stft.win_size
hop_size=args.stft.hop_size
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
if clean is not None:
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
return stft_noisy_stacked, stft_clean_stacked
else:
return stft_noisy_stacked
from tester import Tester
testPath=os.path.join(path_experiment,"final_test")
if not os.path.exists(testPath):
os.makedirs(testPath)
tester=Tester(unet_model, testPath, args)
PEAQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PQevalAudio"
PEMOQ_dir="/scratch/work/molinee2/unet_dir/unet_historical_music/PEMOQ"
#run every instrument/SNR combination through the tester
test_sets=[("pianos", path_pianos_test), ("strings", path_strings_test), ("orchestra", path_orchestra_test), ("opera", path_opera_test)]
for noise_amount in ("mid_snr", "low_snr"):
for name, path in test_sets:
dataset=dataset_loader.load_data_formal( path, path_noise, noise_amount=noise_amount, num_samples=numsamples, fs=fs, seg_len_s=seg_len_s, stereo=stereo)
dataset=dataset.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
tester.init_inference(dataset,numsamples,fs,args.stft, PEAQ_dir, PEMOQ_dir=PEMOQ_dir)
metrics=tester.inference(name+"_"+noise_amount.replace("_snr","snr"))
names=["strings_midsnr","strings_lowsnr","opera_midsnr","opera_lowsnr","pianos_midsnr","pianos_lowsnr","orchestra_midsnr","orchestra_lowsnr"]
for n in names:
a=pd.read_csv(os.path.join(testPath,n,"metrics.csv"))
meanPEAQ=a["PEAQ(ODG)_diff"].sum()/50
meanPEMOQ=a["PEMOQ(ODG)_diff"].sum()/50
meanSDR=a["SDR_diff"].sum()/50
print(n,": PEAQ ",str(meanPEAQ), "PEMOQ ", str(meanPEMOQ), "SDR ", str(meanSDR))
def _main(args):
global __file__
__file__ = hydra.utils.to_absolute_path(__file__)
run(args)
@hydra.main(config_path="conf/conf.yaml")
def main(args):
try:
_main(args)
except Exception:
logger.exception("Some error happened")
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
os._exit(1)
if __name__ == "__main__":
main()

tester.py Normal file (391 lines)
@@ -0,0 +1,391 @@
import os
import numpy as np
import cv2
import librosa
import imageio
import tensorflow as tf
import soundfile as sf
import subprocess
from tqdm import tqdm
from vggish.vgg_distance import process_wav
import pandas as pd
from scipy.io import loadmat
class Tester():
def __init__(self, model, path_experiment, args):
if model is not None:
self.model=model
print(self.model.summary())
self.args=args
self.path_experiment=path_experiment
def init_inference(self, dataset_test=None,num_test_segments=0 , fs=44100, stft_args=None, PEAQ_dir=None, alg_dir=None, PEMOQ_dir=None, dataset_real=None, num_real_test_segments=0):
self.num_test_segments=num_test_segments
self.dataset_test=dataset_test
if self.dataset_test is not None:
self.dataset_test=self.dataset_test.take(self.num_test_segments)
self.dataset_real=dataset_real #needed by inference_real, which reads these attributes
self.num_real_test_segments=num_real_test_segments
self.fs=fs
self.stft_args=stft_args
self.win_size=stft_args.win_size
self.hop_size=stft_args.hop_size
self.window=stft_args.window
self.PEAQ_dir=PEAQ_dir
self.PEMOQ_dir=PEMOQ_dir
self.alg_dir=alg_dir
def generate_inverse_window(self, stft_args):
if stft_args.window=="hamming":
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.hamming_window)
elif stft_args.window=="hann":
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.hann_window)
elif stft_args.window=="kaiser_bessel":
return tf.signal.inverse_stft_window_fn(stft_args.hop_size, forward_window_fn=tf.signal.kaiser_bessel_derived_window)
def do_istft(self,data):
window_fn = self.generate_inverse_window(self.stft_args)
win_size=self.win_size
hop_size=self.hop_size
pred_cpx=data[...,0] + 1j * data[...,1]
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=window_fn)
return pred_time
def generate_images(self,cpx,name):
spectro=np.clip((np.flipud(np.transpose(10*np.log10(np.sqrt(np.power(cpx[...,0],2)+np.power(cpx[...,1],2)))))+30)/50,0,1)
spectrorgb=np.zeros(shape=(spectro.shape[0],spectro.shape[1],3))
spectrorgb[...,0]=np.clip((np.flipud(np.transpose(10*np.log10(np.abs(cpx[...,0])+0.001)))+30)/50,0,1)
spectrorgb[...,1]=np.clip((np.flipud(np.transpose(10*np.log10(np.abs(cpx[...,1])+0.001)))+30)/50,0,1)
cmap=cv2.COLORMAP_JET
spectro = np.array((1-spectro)* 255, dtype = np.uint8)
spectro = cv2.applyColorMap(spectro, cmap)
imageio.imwrite(os.path.join(self.test_results_filepath, name+".png"),spectro)
spectrorgb = np.array(spectrorgb* 255, dtype = np.uint8)
imageio.imwrite(os.path.join(self.test_results_filepath, name+"_ir.png"),spectrorgb)
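#generate_images maps the log-magnitude to [0,1] via (10*log10|X| + 30)/50,
#i.e. a -30 dB..+20 dB display range, inverts it and applies OpenCV's JET
#colormap; the "_ir" image instead shows the real/imag channels in red/green.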
def generate_image_diff(self,clean , pred,name):
difference=np.sqrt((clean[...,0]-pred[...,0])**2+(clean[...,1]-pred[...,1])**2)
dif=np.clip(np.flipud(np.transpose(difference)),0,1)
cmap=cv2.COLORMAP_JET
dif = np.array((1-dif)* 255, dtype = np.uint8)
dif = cv2.applyColorMap(dif, cmap)
imageio.imwrite(os.path.join(self.test_results_filepath, name+"_diff.png"),dif)
def inference_inner_classical(self, folder_name, method):
nums=[]
PEAQ_odg_noisy=[]
PEAQ_odg_output=[]
PEAQ_odg_diff=[]
PEMOQ_odg_noisy=[]
PEMOQ_odg_output=[]
PEMOQ_odg_diff=[]
SDR_noisy=[]
SDR_output=[]
SDR_diff=[]
VGGish_noisy=[]
VGGish_output=[]
VGGish_diff=[]
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
if not os.path.exists(self.test_results_filepath):
os.makedirs(self.test_results_filepath)
num=0
for element in tqdm(self.dataset_test.take(self.num_test_segments)):
test_element=tf.data.Dataset.from_tensors(element)
noisy_time=element[0].numpy()
#noisy_time=self.do_istft(noisy)
name_noisy=str(num)+'_noisy'
clean_time=element[1].numpy()
#clean_time=self.do_istft(clean)
name_clean=str(num)+'_clean'
print("inferencing")
nums.append(num)
print("generating wavs")
#noisy_time=noisy_time.numpy().astype(np.float32)
noisy_time=noisy_time.astype(np.float32)
wav_noisy_name_pre=os.path.join(self.test_results_filepath, name_noisy+"pre.wav")
sf.write(wav_noisy_name_pre, noisy_time, 44100)
#pred = self.model.predict(test_element.batch(1))
name_pred=str(num)+'_output'
wav_output_name_proc=os.path.join(self.test_results_filepath, name_pred+"proc.wav")
self.process_in_matlab(wav_noisy_name_pre, wav_output_name_proc, method)
noisy_time=noisy_time[44100::] #remove pre noise
#clean_time=clean_time.numpy().astype(np.float32)
clean_time=clean_time.astype(np.float32)
clean_time=clean_time[44100::] #remove pre noise
#change that !!!!
#pred_time=self.do_istft(pred[0])
#pred_time=pred_time.numpy().astype(np.float32)
#pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000)
#sf.write(wav_output_name, pred_time, 48000)
#LOAD THE AUDIO!!!
pred_time, sr=sf.read(wav_output_name_proc)
assert sr==44100
pred_time=pred_time[44100::] #remove prenoise
#I am computing here the SDR at 48k, while I was doing it before at 44.1k. I hope this won't cause any problem in the results. Consider resampling???
SDR_t_noisy=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(noisy_time-clean_time)))
SDR_noisy.append(SDR_t_noisy)
SDR_t_output=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(pred_time-clean_time)))
SDR_output.append(SDR_t_output)
SDR_diff.append(SDR_t_output-SDR_t_noisy)
noisy_time=librosa.resample(np.transpose(noisy_time),self.fs, 48000) #P.Kabal PEAQ code is hardcoded at Fs=48000, so we have to resample
wav_noisy_name=os.path.join(self.test_results_filepath, name_noisy+".wav")
sf.write(wav_noisy_name, noisy_time, 48000) #overwrite without prenoise
clean_time=librosa.resample(np.transpose(clean_time),self.fs, 48000) #without prenoise please!!!
wav_clean_name=os.path.join(self.test_results_filepath, name_clean+".wav")
sf.write(wav_clean_name, clean_time, 48000)
pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000) #without prenoise please!!!
wav_output_name=os.path.join(self.test_results_filepath, name_pred+".wav")
sf.write(wav_output_name, pred_time, 48000)
#save pred at 48k
#print("calculating PEMOQ")
#odg_noisy,odg_output =self.calculate_PEMOQ(wav_clean_name,wav_noisy_name,wav_output_name)
#PEMOQ_odg_noisy.append(odg_noisy)
#PEMOQ_odg_output.append(odg_output)
#PEMOQ_odg_diff.append(odg_output-odg_noisy)
#print("calculating PEAQ")
#odg_noisy,odg_output =self.calculate_PEAQ(wav_clean_name,wav_noisy_name,wav_output_name)
#PEAQ_odg_noisy.append(odg_noisy)
#PEAQ_odg_output.append(odg_output)
#PEAQ_odg_diff.append(odg_output-odg_noisy)
print("calculating VGGish")
VGGish_clean_embeddings=process_wav(wav_clean_name)
VGGish_noisy_embeddings=process_wav(wav_noisy_name)
VGGish_output_embeddings=process_wav(wav_output_name)
dist_noisy = np.linalg.norm(VGGish_noisy_embeddings-VGGish_clean_embeddings)
dist_output = np.linalg.norm(VGGish_output_embeddings-VGGish_clean_embeddings)
VGGish_noisy.append(dist_noisy)
VGGish_output.append(dist_output)
VGGish_diff.append(-(dist_output-dist_noisy))
os.remove(wav_clean_name)
os.remove(wav_noisy_name)
os.remove(wav_noisy_name_pre)
os.remove(wav_output_name)
os.remove(wav_output_name_proc)
num=num+1
frame = { 'num':nums, 'SDR_noisy': SDR_noisy, 'SDR_output': SDR_output, 'SDR_diff': SDR_diff, 'VGGish_noisy': VGGish_noisy, 'VGGish_output': VGGish_output,'VGGish_diff': VGGish_diff } #PEAQ/PEMO-Q are commented out above, so their empty lists are left out (pandas rejects columns of unequal length)
metrics=pd.DataFrame(frame)
metrics.to_csv(os.path.join(self.test_results_filepath,"metrics.csv"),index=False)
metrics=metrics.set_index('num')
return metrics
def inference_inner(self, folder_name):
nums=[]
PEAQ_odg_noisy=[]
PEAQ_odg_output=[]
PEAQ_odg_diff=[]
PEMOQ_odg_noisy=[]
PEMOQ_odg_output=[]
PEMOQ_odg_diff=[]
SDR_noisy=[]
SDR_output=[]
SDR_diff=[]
VGGish_noisy=[]
VGGish_output=[]
VGGish_diff=[]
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
if not os.path.exists(self.test_results_filepath):
os.makedirs(self.test_results_filepath)
num=0
for element in tqdm(self.dataset_test.take(self.num_test_segments)):
test_element=tf.data.Dataset.from_tensors(element)
noisy=element[0].numpy()
noisy_time=self.do_istft(noisy)
name_noisy=str(num)+'_noisy'
clean=element[1].numpy()
clean_time=self.do_istft(clean)
name_clean=str(num)+'_clean'
print("inferencing")
pred = self.model.predict(test_element.batch(1))
if self.args.unet.num_stages==2:
pred=pred[0]
pred_time=self.do_istft(pred[0])
name_pred=str(num)+'_output'
nums.append(num)
pred_time=pred_time.numpy().astype(np.float32)
clean_time=clean_time.numpy().astype(np.float32)
SDR_t_noisy=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(noisy_time-clean_time)))
SDR_t_output=10*np.log10(np.mean(np.square(clean_time))/np.mean(np.square(pred_time-clean_time)))
SDR_noisy.append(SDR_t_noisy)
SDR_output.append(SDR_t_output)
SDR_diff.append(SDR_t_output-SDR_t_noisy)
print("generating wavs")
noisy_time=librosa.resample(np.transpose(noisy_time),self.fs, 48000) #P.Kabal PEAQ code is hardcoded at Fs=48000, so we have to resample
clean_time=librosa.resample(np.transpose(clean_time),self.fs, 48000)
pred_time=librosa.resample(np.transpose(pred_time),self.fs, 48000)
wav_noisy_name=os.path.join(self.test_results_filepath, name_noisy+".wav")
sf.write(wav_noisy_name, noisy_time, 48000)
wav_clean_name=os.path.join(self.test_results_filepath, name_clean+".wav")
sf.write(wav_clean_name, clean_time, 48000)
wav_output_name=os.path.join(self.test_results_filepath, name_pred+".wav")
sf.write(wav_output_name, pred_time, 48000)
print("calculating PEMOQ")
odg_noisy,odg_output =self.calculate_PEMOQ(wav_clean_name,wav_noisy_name,wav_output_name)
PEMOQ_odg_noisy.append(odg_noisy)
PEMOQ_odg_output.append(odg_output)
PEMOQ_odg_diff.append(odg_output-odg_noisy)
print("calculating PEAQ")
odg_noisy,odg_output =self.calculate_PEAQ(wav_clean_name,wav_noisy_name,wav_output_name)
PEAQ_odg_noisy.append(odg_noisy)
PEAQ_odg_output.append(odg_output)
PEAQ_odg_diff.append(odg_output-odg_noisy)
print("calculating VGGish")
VGGish_clean_embeddings=process_wav(wav_clean_name)
VGGish_noisy_embeddings=process_wav(wav_noisy_name)
VGGish_output_embeddings=process_wav(wav_output_name)
dist_noisy = np.linalg.norm(VGGish_noisy_embeddings-VGGish_clean_embeddings)
dist_output = np.linalg.norm(VGGish_output_embeddings-VGGish_clean_embeddings)
VGGish_noisy.append(dist_noisy)
VGGish_output.append(dist_output)
VGGish_diff.append(-(dist_output-dist_noisy))
os.remove(wav_clean_name)
os.remove(wav_noisy_name)
os.remove(wav_output_name)
num=num+1
frame = { 'num':nums,'PEAQ(ODG)_noisy': PEAQ_odg_noisy, 'PEAQ(ODG)_output': PEAQ_odg_output, 'PEAQ(ODG)_diff': PEAQ_odg_diff, 'PEMOQ(ODG)_noisy': PEMOQ_odg_noisy, 'PEMOQ(ODG)_output': PEMOQ_odg_output, 'PEMOQ(ODG)_diff': PEMOQ_odg_diff,'SDR_noisy': SDR_noisy, 'SDR_output': SDR_output, 'SDR_diff': SDR_diff, 'VGGish_noisy': VGGish_noisy, 'VGGish_output': VGGish_output,'VGGish_diff': VGGish_diff }
metrics=pd.DataFrame(frame)
metrics.to_csv(os.path.join(self.test_results_filepath,"metrics.csv"),index=False)
metrics=metrics.set_index('num')
return metrics
def inference_real(self, folder_name):
self.test_results_filepath = os.path.join(self.path_experiment,folder_name)
if not os.path.exists(self.test_results_filepath):
os.makedirs(self.test_results_filepath)
num=0
for element in tqdm(self.dataset_real.take(self.num_real_test_segments)):
test_element=tf.data.Dataset.from_tensors(element)
noisy=element.numpy()
noisy_time=self.do_istft(noisy)
name_noisy="recording_"+str(num)+'_noisy.wav'
pred = self.model.predict(test_element.batch(1))
if self.args.unet.num_stages==2:
pred=pred[0]
pred_time=self.do_istft(pred[0])
name_pred="recording_"+str(num)+'_output.wav'
sf.write(os.path.join(self.test_results_filepath, name_noisy), noisy_time, self.fs)
sf.write(os.path.join(self.test_results_filepath, name_pred), pred_time, self.fs)
self.generate_images(noisy,name_noisy)
self.generate_images(pred[0],name_pred)
num=num+1
def process_in_matlab(self,wav_noisy_name,wav_output_name,mode): #Opens and closes MATLAB to run the classical declick/denoise baseline, rudimentary way to do it but easier. Make sure to have matlab installed
addpath=self.alg_dir
#odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_noisy.mat")
#odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pred.mat")
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), declick_and_denoise(\""+wav_noisy_name+"\",\""+wav_output_name+"\",\""+mode+"\") , exit'"
print(bashCommand)
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
(output, err) = p1.communicate()
print(output)
p1.wait()
def calculate_PEMOQ(self,wav_clean_name,wav_noisy_name,wav_output_name): #Opening and closing matlab to calculate PEMO-Q, rudimentary way to do it but easier. Make sure to have matlab installed
addpath=self.PEMOQ_dir
odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_pemo_noisy.mat")
odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pemo_pred.mat")
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [ ODG]=PEMOQ(\""+wav_clean_name+"\",\""+wav_noisy_name+"\"), save(\""+odgmatfile_noisy+"\",\"ODG\"), exit'"
print(bashCommand)
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
(output, err) = p1.communicate()
print(output)
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [ ODG]=PEMOQ(\""+wav_clean_name+"\",\""+wav_output_name+"\"), save(\""+odgmatfile_pred+"\",\"ODG\"), exit'"
p2 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
(output, err) = p2.communicate()
print(output)
p1.wait()
p2.wait()
#I save the odg results in a .mat file, which I load here. Not the most optimal method, sorry :/
annots_noise = loadmat(odgmatfile_noisy)
annots_pred = loadmat(odgmatfile_pred)
#Consider loading also the movs!!
return annots_noise["ODG"][0][0], annots_pred["ODG"][0][0]
def calculate_PEAQ(self,wav_clean_name,wav_noisy_name,wav_output_name): #Opening and closing matlab to calculate PEAQ, rudimentary way to do it but easier. Make sure to have matlab installed
addpath=self.PEAQ_dir
odgmatfile_noisy=os.path.join(self.test_results_filepath, "odg_noisy.mat")
odgmatfile_pred=os.path.join(self.test_results_filepath, "odg_pred.mat")
#bashCommand = "matlab -nodesktop -r 'addpath(\"PQevalAudio\", \"PQevalAudio/CB\",\"PQevalAudio/Misc\",\"PQevalAudio/MOV\", \"PQevalAudio/Patt\"), [odg, MOV]=PQevalAudio(\"0_clean_48.wav\",\"0_noise_48.wav\"), save(\"odg_noisy.mat\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [odg, MOV]=PQevalAudio(\""+wav_clean_name+"\",\""+wav_noisy_name+"\"), save(\""+odgmatfile_noisy+"\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
p1 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
(output, err) = p1.communicate()
print(output)
bashCommand = "matlab -nodesktop -r 'addpath(genpath(\""+addpath+"\")), [odg, MOV]=PQevalAudio(\""+wav_clean_name+"\",\""+wav_output_name+"\"), save(\""+odgmatfile_pred+"\",\"odg\"), save(\"mov.mat\",\"MOV\") , exit'"
p2 = subprocess.Popen(bashCommand, stdout=subprocess.PIPE, shell=True)
(output, err) = p2.communicate()
print(output)
p1.wait()
p2.wait()
#I save the odg results in a .mat file, which I load here. Not the most optimal method, sorry :/
annots_noise = loadmat(odgmatfile_noisy)
annots_pred = loadmat(odgmatfile_pred)
#Consider loading also the movs!!
return annots_noise["odg"][0][0], annots_pred["odg"][0][0]
def inference(self, name, method=None):
print("Inferencing :",name)
if self.dataset_test is not None:
if method=="EM":
return self.inference_inner_classical(name, "EM")
elif method=="wiener":
return self.inference_inner_classical(name, "wiener")
elif method=="wiener_declick":
return self.inference_inner_classical(name, "wiener_declick")
elif method=="EM_declick":
return self.inference_inner_classical(name, "EM_declick")
else:
return self.inference_inner(name)

train.py Normal file (171 lines)
@@ -0,0 +1,171 @@
import os
import hydra
import logging
logger = logging.getLogger(__name__)
def run(args):
import unet
import tensorflow as tf
import tensorflow_addons as tfa
import dataset_loader
from tensorflow.keras.optimizers import Adam
import soundfile as sf
import datetime
from tqdm import tqdm
import numpy as np
path_experiment=str(args.path_experiment)
if not os.path.exists(path_experiment):
os.makedirs(path_experiment)
path_music_train=args.dset.path_music_train
path_music_test=args.dset.path_music_test
path_music_validation=args.dset.path_music_validation
path_noise=args.dset.path_noise
path_recordings=args.dset.path_recordings
fs=args.fs
overlap=args.overlap
seg_len_s_train=args.seg_len_s_train
batch_size=args.batch_size
epochs=args.epochs
num_real_test_segments=args.num_real_test_segments
buffer_size=args.buffer_size #for shuffle
tensorboard_logs=args.tensorboard_logs
def do_stft(noisy, clean=None):
window_fn = tf.signal.hamming_window
win_size=args.stft.win_size
hop_size=args.stft.hop_size
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
if clean is not None:
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
return stft_noisy_stacked, stft_clean_stacked
else:
return stft_noisy_stacked
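#Shape sketch for do_stft (assuming win_size=2048 and hop_size=512, which match
#the 1025 frequency bins used elsewhere; the actual values come from
#conf/conf.yaml): a 5 s mono signal of 220500 samples maps to roughly
#[427, 1025, 2], the last axis stacking the real and imaginary STFT parts.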
#Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
dataset_train, dataset_val=dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)
dataset_train=dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
dataset_val=dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
#build the model
unet_model = unet.build_model_denoise(unet_args=args.unet)
current_lr=args.lr
optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)
loss=tf.keras.losses.MeanAbsoluteError()
if args.use_tensorboard:
log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
train_summary_writer = tf.summary.create_file_writer(log_dir+"/train")
val_summary_writer = tf.summary.create_file_writer(log_dir+"/validation")
#path where the checkpoints will be saved
checkpoint_filepath=os.path.join(path_experiment, 'checkpoint')
dataset_train=dataset_train.batch(batch_size)
dataset_val=dataset_val.batch(batch_size)
#prefetching the dataset for better performance
dataset_train=dataset_train.prefetch(batch_size*20)
dataset_val=dataset_val.prefetch(batch_size*20)
dataset_train=strategy.experimental_distribute_dataset(dataset_train)
dataset_val=strategy.experimental_distribute_dataset(dataset_val)
iterator = iter(dataset_train)
from trainer import Trainer
trainer=Trainer(unet_model,optimizer,loss,strategy, path_experiment, args)
for epoch in range(epochs):
total_loss=0
step_loss=0
for step in tqdm(range(args.steps_per_epoch), desc="Training epoch "+str(epoch)):
step_loss=trainer.distributed_training_step(iterator.get_next())
total_loss+=step_loss
if args.use_tensorboard: #the writers only exist when tensorboard logging is enabled
with train_summary_writer.as_default():
tf.summary.scalar('batch_loss', step_loss, step=step)
tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)
train_loss=total_loss/args.steps_per_epoch
for x in tqdm(dataset_val, desc="Validating epoch "+str(epoch)):
trainer.distributed_test_step(x)
template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
if args.use_tensorboard:
with train_summary_writer.as_default():
tf.summary.scalar('epoch_loss', train_loss, step=epoch)
tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
with val_summary_writer.as_default():
tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)
trainer.train_mae.reset_states()
trainer.val_loss.reset_states()
trainer.val_mae.reset_states()
if (epoch+1) % 50 == 0:
if args.variable_lr:
current_lr*=1e-1
trainer.optimizer.lr=current_lr
try:
unet_model.save_weights(checkpoint_filepath)
except Exception:
pass #saving is best-effort; keep training even if a checkpoint write fails
def _main(args):
global __file__
__file__ = hydra.utils.to_absolute_path(__file__)
run(args)
@hydra.main(config_path="conf/conf.yaml")
def main(args):
try:
_main(args)
except Exception:
logger.exception("Some error happened")
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
os._exit(1)
if __name__ == "__main__":
main()

trainer.py Normal file (71 lines)
@@ -0,0 +1,71 @@
import os
import numpy as np
import cv2
import librosa
import imageio
import tensorflow as tf
import soundfile as sf
import subprocess
from tqdm import tqdm
import pandas as pd
from scipy.io import loadmat
class Trainer():
def __init__(self, model, optimizer,loss, strategy, path_experiment, args):
self.model=model
print(self.model.summary())
self.strategy=strategy
self.optimizer=optimizer
self.path_experiment=path_experiment
self.args=args
#self.metrics=[]
with self.strategy.scope():
#loss_fn=tf.keras.losses.mean_absolute_error
loss.reduction=tf.keras.losses.Reduction.NONE
self.loss_object=loss
self.train_mae_s1=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
self.train_mae=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
self.val_mae=tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
self.val_loss = tf.keras.metrics.Mean(name='test_loss')
def train_step(self,inputs):
noisy, clean= inputs
with tf.GradientTape() as tape:
logits_2,logits_1 = self.model(noisy, training=True) # Logits for this minibatch
loss_value = tf.reduce_mean(self.loss_object(clean, logits_2)) + tf.reduce_mean(self.loss_object(clean, logits_1))
grads = tape.gradient(loss_value, self.model.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
self.train_mae.update_state(clean, logits_2)
self.train_mae_s1.update_state(clean, logits_1)
return loss_value
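#The loss above deeply supervises both stages: total loss = MAE(stage 2) +
#MAE(stage 1), so the first stage receives a direct gradient signal even
#though only the stage-2 output is the final prediction.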
def test_step(self,inputs):
noisy,clean = inputs
predictions_s2, predictions_s1 = self.model(noisy, training=False)
t_loss = self.loss_object(clean, predictions_s2)+self.loss_object(clean, predictions_s1)
self.val_mae.update_state(clean,predictions_s2)
self.val_loss.update_state(t_loss)
@tf.function()
def distributed_training_step(self,inputs):
per_replica_losses=self.strategy.run(self.train_step, args=(inputs,))
reduced_losses=self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return reduced_losses
@tf.function
def distributed_test_step(self,inputs):
return self.strategy.run(self.test_step, args=(inputs,))

unet.py Normal file (486 lines)
@@ -0,0 +1,486 @@
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m
def build_model_denoise(unet_args=None):
inputs=Input(shape=(None, None,2))
outputs_stage_2,outputs_stage_1=MultiStage_denoise(unet_args=unet_args)(inputs)
#Encapsulating MultiStage_denoise in a keras.Model object
model= tf.keras.Model(inputs=inputs,outputs=[outputs_stage_2, outputs_stage_1])
return model
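#Minimal sketch of building and running the model (unet_args fields follow the
#attributes referenced in MultiStage_denoise below, assuming f_dim=1025; the
#values are illustrative, my_unet_args is a hypothetical config object):
#  model = build_model_denoise(unet_args=my_unet_args)
#  out_stage2, out_stage1 = model(tf.zeros([1, 427, 1025, 2]))
#Note this unpacking assumes unet_args.num_stages==2; with a single stage,
#MultiStage_denoise returns only one output.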
class DenseBlock(layers.Layer):
'''
[B, T, F, N] => [B, T, F, N]
DenseNet Block consisting of "num_layers" densely connected convolutional layers
'''
def __init__(self, num_layers, N, ksize,activation):
'''
num_layers: number of densely connected conv. layers
N: Number of filters (same in each layer)
ksize: Kernel size (same in each layer)
'''
super(DenseBlock, self).__init__()
self.activation=activation
self.paddings_1=get_paddings(ksize)
self.H=[]
self.num_layers=num_layers
for i in range(num_layers):
self.H.append(layers.Conv2D(filters=N,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=self.activation))
def call(self, x):
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
x_ = self.H[0](x_)
if self.num_layers>1:
for h in self.H[1:]:
x = tf.concat([x_, x], axis=-1)
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
x_ = h(x_)
return x_
class FinalBlock(layers.Layer):
'''
[B, T, F, N] => [B, T, F, 2]
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
'''
def __init__(self):
super(FinalBlock, self).__init__()
ksize=(3,3)
self.paddings_2=get_paddings(ksize)
self.conv2=layers.Conv2D(filters=2,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=None)
def call(self, inputs ):
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
pred=self.conv2(x)
return pred
class SAM(layers.Layer):
'''
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
Supervised Attention Module:
The purpose of SAM is to make the network propagate only the most relevant features to the second stage, discarding the less useful ones.
The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer.
The first-stage output is then calculated by adding the estimated residual noise to the original input spectrogram.
The attention-guided features are computed using attention masks M, which are calculated directly from the first-stage output with a convolution and a sigmoid function.
'''
def __init__(self, n_feat):
super(SAM, self).__init__()
ksize=(3,3)
self.paddings_1=get_paddings(ksize)
self.conv1 = layers.Conv2D(filters=n_feat,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=None)
ksize=(3,3)
self.paddings_2=get_paddings(ksize)
self.conv2=layers.Conv2D(filters=2,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=None)
ksize=(3,3)
self.paddings_3=get_paddings(ksize)
self.conv3 = layers.Conv2D(filters=n_feat,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=None)
self.cropadd=CropAddBlock()
def call(self, inputs, input_spectrogram):
x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
x1 = self.conv1(x1)
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
x=self.conv2(x)
#residual prediction
pred = layers.Add()([x, input_spectrogram]) #first-stage output (input + estimated residual)
x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
M=self.conv3(x3)
M= tf.keras.activations.sigmoid(M)
x1=layers.Multiply()([x1, M])
x1 = layers.Add()([x1, inputs]) #features to next stage
return x1, pred
class AddFreqEncoding(layers.Layer):
'''
[B, T, F, 2] => [B, T, F, 12]
Generates frequency positional embeddings and concatenates them as 10 extra channels
This function is optimized for F=1025
'''
def __init__(self, f_dim):
super(AddFreqEncoding, self).__init__()
pi = tf.constant(m.pi)
pi=tf.cast(pi,'float32')
self.f_dim=f_dim #f_dim is fixed
n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
coss=tf.math.cos(pi*n)
f_channel = tf.expand_dims(coss, -1) #(1025,1)
self.fembeddings= f_channel
for k in range(1,10):
coss=tf.math.cos(2**k*pi*n)
f_channel = tf.expand_dims(coss, -1) #(1025,1)
self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
def call(self, input_tensor):
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
time_dim = tf.shape(input_tensor)[1] # get time dimension
fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
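#e.g. a [B, T, 1025, 2] spectrogram becomes [B, T, 1025, 12]: the 10 extra
#channels are cos(2**k * pi * n) for k = 0..9, with n the normalized bin index.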
def get_paddings(K):
return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]])
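#get_paddings computes the per-side symmetric padding that keeps the output
#size equal to the input size for a stride-1 'VALID' convolution: a (3,3)
#kernel yields [[0,0],[1,1],[1,1],[0,0]], while an even kernel such as (4,4)
#pads one sample less on the bottom/right: [[0,0],[2,1],[2,1],[0,0]].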
class Decoder(layers.Layer):
'''
[B, T, F, N] , skip connections => [B, T, F, N]
Decoder side of the U-Net subnetwork.
'''
def __init__(self, Ns, Ss, unet_args):
super(Decoder, self).__init__()
self.Ns=Ns
self.Ss=Ss
self.activation=unet_args.activation
self.depth=unet_args.depth
ksize=(3,3)
self.paddings_3=get_paddings(ksize)
self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=self.activation)
self.cropadd=CropAddBlock()
self.dblocks=[]
for i in range(self.depth):
self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc))
def call(self,inputs, contracting_layers):
x=inputs
for i in range(self.depth,0,-1):
x=self.dblocks[i-1](x, contracting_layers[i-1])
return x
class Encoder(tf.keras.Model):
'''
[B, T, F, N] => skip connections , [B, T, F, N_4]
Encoder side of the U-Net subnetwork.
'''
def __init__(self, Ns, Ss, unet_args):
super(Encoder, self).__init__()
self.Ns=Ns
self.Ss=Ss
self.activation=unet_args.activation
self.depth=unet_args.depth
self.contracting_layers = {}
self.eblocks=[]
for i in range(self.depth):
self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
def call(self, inputs):
x=inputs
for i in range(self.depth):
x, x_contract=self.eblocks[i](x)
self.contracting_layers[i] = x_contract #if remove 0, correct this
x=self.i_block(x)
return x, self.contracting_layers
class MultiStage_denoise(tf.keras.Model):
def __init__(self, unet_args=None):
super(MultiStage_denoise, self).__init__()
self.activation=unet_args.activation
self.depth=unet_args.depth
if unet_args.use_fencoding:
self.freq_encoding=AddFreqEncoding(unet_args.f_dim)
self.use_sam=unet_args.use_SAM
self.use_fencoding=unet_args.use_fencoding
self.num_stages=unet_args.num_stages
#Encoder
self.Ns= [32,64,64,128,128,256,512]
self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)]
#initial feature extractor
ksize=(7,7)
self.paddings_1=get_paddings(ksize)
self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=self.activation)
self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
self.cropconcat = CropConcatBlock()
self.cropadd = CropAddBlock()
self.finalblock=FinalBlock()
if self.num_stages>1:
self.sam_1=SAM(self.Ns[0])
#initial feature extractor
ksize=(7,7)
self.paddings_2=get_paddings(ksize)
self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID',
activation=self.activation)
self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
@tf.function()
def call(self, inputs):
if self.use_fencoding:
x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
else:
x_w_freq=inputs
#intitial feature extractor
x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
x=self.conv2d_1(x) #None, None, 1025, 32
x, contracting_layers_s1= self.encoder_s1(x)
#decoder
feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
if self.num_stages>1:
#SAM module
Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
#intitial feature extractor
x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
x=self.conv2d_2(x)
if self.use_sam:
x = tf.concat([x, Fout], axis=-1)
else:
x = tf.concat([x,feats_s1], axis=-1)
x, contracting_layers_s2= self.encoder_s2(x)
feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
#consider implementing a third stage?
pred_stage_2=self.finalblock(feats_s2)
return pred_stage_2, pred_stage_1
else:
pred_stage_1=self.finalblock(feats_s1)
return pred_stage_1
class I_Block(layers.Layer):
'''
[B, T, F, N] => [B, T, F, N]
Intermediate block:
Basically, a densenet block with a residual connection
'''
def __init__(self,N,activation, num_tfc, **kwargs):
super(I_Block, self).__init__(**kwargs)
ksize=(3,3)
self.tfc=DenseBlock(num_tfc,N,ksize, activation)
self.conv2d_res= layers.Conv2D(filters=N,
kernel_size=(1,1),
kernel_initializer=TruncatedNormal(),
strides=1,
padding='VALID')
def call(self,inputs):
x=self.tfc(inputs)
inputs_proj=self.conv2d_res(inputs)
return layers.Add()([x,inputs_proj])
class E_Block(layers.Layer):
def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs):
super(E_Block, self).__init__(**kwargs)
self.layer_idx=layer_idx
self.N0=N0
self.N=N
self.S=S
self.activation=activation
self.i_block=I_Block(N0,activation,num_tfc)
ksize=(S[0]+2,S[1]+2)
self.paddings_2=get_paddings(ksize)
self.conv2d_2 = layers.Conv2D(filters=N,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=S,
padding='VALID',
activation=self.activation)
def call(self, inputs, training=None, **kwargs):
x=self.i_block(inputs)
x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
x_down = self.conv2d_2(x_down)
return x_down, x
def get_config(self):
return dict(layer_idx=self.layer_idx,
N=self.N,
S=self.S,
**super(E_Block, self).get_config()
)
class D_Block(layers.Layer):
def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs):
super(D_Block, self).__init__(**kwargs)
self.layer_idx=layer_idx
self.N=N
self.S=S
self.activation=activation
ksize=(S[0]+2, S[1]+2)
self.paddings_1=get_paddings(ksize)
self.tconv_1= layers.Conv2DTranspose(filters=N,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=S,
activation=self.activation,
padding='VALID')
self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
self.projection = layers.Conv2D(filters=N,
kernel_size=(1,1),
kernel_initializer=TruncatedNormal(),
strides=1,
activation=self.activation,
padding='VALID')
self.cropadd=CropAddBlock()
self.cropconcat=CropConcatBlock()
self.i_block=I_Block(N,activation,num_tfc)
def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs):
x = inputs
x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
x = self.tconv_1(x) #apply the transposed conv to the padded tensor; the padding was computed but previously unused
x2= self.upsampling(inputs)
if x2.shape[-1]!=x.shape[-1]:
x2= self.projection(x2)
x= self.cropadd(x,x2)
x=self.cropconcat(x,bridge)
x=self.i_block(x)
return x
def get_config(self):
return dict(layer_idx=self.layer_idx,
N=self.N,
S=self.S,
**super(D_Block, self).get_config()
)
class CropAddBlock(layers.Layer):
def call(self,down_layer, x, **kwargs):
x1_shape = tf.shape(down_layer)
x2_shape = tf.shape(x)
height_diff = (x1_shape[1] - x2_shape[1]) // 2
width_diff = (x1_shape[2] - x2_shape[2]) // 2
down_layer_cropped = down_layer[:,
height_diff: (x2_shape[1] + height_diff),
width_diff: (x2_shape[2] + width_diff),
:]
x = layers.Add()([down_layer_cropped, x])
return x
class CropConcatBlock(layers.Layer):
def call(self, down_layer, x, **kwargs):
x1_shape = tf.shape(down_layer)
x2_shape = tf.shape(x)
height_diff = (x1_shape[1] - x2_shape[1]) // 2
width_diff = (x1_shape[2] - x2_shape[2]) // 2
down_layer_cropped = down_layer[:,
height_diff: (x2_shape[1] + height_diff),
width_diff: (x2_shape[2] + width_diff),
:]
x = tf.concat([down_layer_cropped, x], axis=-1)
return x