From c4cbdd2b8abb3c59ccf852c5a933c9ac03c1584e Mon Sep 17 00:00:00 2001
From: Moliner Eloi
Date: Mon, 30 Aug 2021 18:30:51 +0300
Subject: [PATCH] startin again

---
 dataset_loader.py | 497 ++++++++++++++++++++++++++++++++++++++++++++++
 inference.py      | 164 +++++++++++++++
 test.py           | 151 ++++++++++++++
 tester.py         | 391 ++++++++++++++++++++++++++++++++++++
 train.py          | 171 ++++++++++++++++
 trainer.py        |  71 +++++++
 unet.py           | 486 +++++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 1931 insertions(+)
 create mode 100644 dataset_loader.py
 create mode 100644 inference.py
 create mode 100644 test.py
 create mode 100644 tester.py
 create mode 100644 train.py
 create mode 100644 trainer.py
 create mode 100644 unet.py

diff --git a/dataset_loader.py b/dataset_loader.py
new file mode 100644
index 0000000..e10e0b6
--- /dev/null
+++ b/dataset_loader.py
@@ -0,0 +1,497 @@
+from typing import Tuple, Dict
+import ast
+
+import tensorflow as tf
+import random
+import os
+import numpy as np
+from scipy.fft import fft, ifft
+import soundfile as sf
+import librosa
+import math
+import pandas as pd
+import scipy as sp
+import glob
+from tqdm import tqdm
+
+#Generator function. It reads the csv info file with pandas and loads audio
+#segments from each recording: for the "test" split, always the largest segment;
+#otherwise, a randomly chosen one. Segments longer than length_seq are trimmed
+#and yielded with no further processing; shorter segments are extended to
+#length_seq using concatenative synthesis (see __extend_sample_by_repeating).
+def __noise_sample_generator(info_file, fs, length_seq, split):
+    head=os.path.split(info_file)[0]
+    load_data=pd.read_csv(info_file)
+    #split: train, validation or test
+    load_data_split=load_data.loc[load_data["split"]==split]
+    load_data_split=load_data_split.reset_index(drop=True)
+    while True:
+        r = list(range(len(load_data_split)))
+        if split!="test":
+            random.shuffle(r)
+        for i in r:
+            segments=ast.literal_eval(load_data_split.loc[i,"segments"])
+            if split=="test":
+                loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
+            else:
+                num=np.random.randint(0,len(segments))
+                loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
+
+            if fs!=Fs:
+                print("wrong fs, resampling...")
+                loaded_data=librosa.resample(loaded_data, Fs, fs)
+
+            yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
+
+def __extend_sample_by_repeating(data, fs, seq_len):
+    rpm=78 #assumed speed of the source recordings
+    target_samp=seq_len
+    large_data=np.zeros(shape=(target_samp,2))
+
+    if len(data)>=target_samp:
+        large_data=data[0:target_samp]
+        return large_data
+
+    bls=(1000*44100)/1000 #crossfade block length: 1000 ms at 44.1 kHz (hardcoded)
+
+    window=np.stack((np.hanning(bls), np.hanning(bls)), axis=1)
+    window_left=window[0:int(bls/2),:]
+    window_right=window[int(bls/2)::,:]
+    bls=int(bls/2)
+
+    rps=rpm/60
+    period=1/rps
+
+    period_sam=int(period*fs)
+
+    overhead=len(data)%period_sam
+
+    if(overhead>bls):
+        complete_periods=(len(data)//period_sam)*period_sam
+    else:
+        complete_periods=(len(data)//period_sam -1)*period_sam
+
+    a=np.multiply(data[0:bls], window_left)
+    b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
+    c_1=np.concatenate((data[0:complete_periods,:],b))
+    c_2=np.concatenate((a,data[bls:complete_periods,:],b))
+    c_3=np.concatenate((a,data[bls::,:]))
+
+    large_data[0:complete_periods+bls,:]=c_1
+
+    pointer=complete_periods
+    not_finished=True
+    while (not_finished):
+        if target_samp>pointer+complete_periods+bls:
+            large_data[pointer:pointer+complete_periods+bls]+=c_2
+            pointer+=complete_periods
+        else:
+            large_data[pointer::]+=c_3[0:(target_samp-pointer)]
+            #finished
+            not_finished=False
+
+    return large_data
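The extension routine above loops an integer number of revolutions of the noise excerpt, blending each repetition with Hann-window crossfades. A minimal mono sketch of that idea on a synthetic signal (function and variable names here are illustrative, not part of the patch):

import numpy as np

def loop_with_crossfade(excerpt, target_len, xfade):
    # Hann fade-in/fade-out halves, as in __extend_sample_by_repeating.
    w = np.hanning(2 * xfade)
    body = excerpt.copy(); body[:xfade] *= w[:xfade]; body[-xfade:] *= w[xfade:]
    head = excerpt.copy(); head[:xfade] *= w[:xfade]
    first = excerpt.copy(); first[-xfade:] *= w[xfade:]
    out = np.zeros(target_len)
    out[:len(excerpt)] = first
    pointer = len(excerpt) - xfade          # repetitions overlap by xfade samples
    while pointer + len(excerpt) <= target_len:
        out[pointer:pointer + len(excerpt)] += body
        pointer += len(excerpt) - xfade
    out[pointer:] += head[:target_len - pointer]  # partial final repetition
    return out

fs = 44100
excerpt = np.random.randn(fs)               # stand-in for one looped noise chunk
extended = loop_with_crossfade(excerpt, 5 * fs, xfade=2048)
assert extended.shape == (5 * fs,)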
+
+def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
+
+    records_info=os.path.join(path_recordings,"audio_files.txt")
+    #load record files
+    print("Loading record files")
+    records=[]
+    seg_len=fs*seg_len_s
+    pointer=int(fs*5) #start at second 5 by default
+    with open(records_info,"r") as f:
+        audio_files=f.read().splitlines()
+    for audio in tqdm(audio_files):
+        data, Fs=sf.read(os.path.join(path_recordings,audio))
+        if Fs!=fs:
+            print("Wrong sample rate!")
+        if len(data.shape)>1 and not(stereo):
+            data=np.mean(data,axis=1)
+
+        #normalize
+        data=data/np.max(np.abs(data))
+        segment=data[pointer:pointer+seg_len]
+        records.append(segment.astype("float32"))
+
+    return records
+
+def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr", num_samples=-1, fs=44100, seg_len_s=5, extend=True, stereo=False, prenoise=False):
+
+    segments_clean=[]
+    segments_noisy=[]
+    seg_len=fs*seg_len_s
+    noises_info=os.path.join(path_noises,"info.csv")
+    np.random.seed(42)
+    if noise_amount=="low_snr":
+        SNRs=np.random.uniform(2,6,num_samples)
+    elif noise_amount=="mid_snr":
+        SNRs=np.random.uniform(6,12,num_samples)
+    else:
+        raise ValueError("unknown noise_amount: "+noise_amount)
+
+    scales=np.random.uniform(-4,0,num_samples)
+    i=0
+    train_samples=glob.glob(os.path.join(path_pianos[0],"*.wav"))
+    train_samples=sorted(train_samples)
+
+    if prenoise:
+        #1 s of silence is prepended to the clean signal, so a longer noise segment is needed
+        noise_generator=__noise_sample_generator(noises_info, fs, seg_len+fs, "test")
+    else:
+        noise_generator=__noise_sample_generator(noises_info, fs, seg_len, "test") #this will take care of everything
+    #load clean data files
+    for file in tqdm(train_samples):
+        data_clean, samplerate = sf.read(file)
+        if samplerate!=fs:
+            print("Wrong sample rate!")
+        #Stereo to mono
+        if len(data_clean.shape)>1 and not(stereo):
+            data_clean=np.mean(data_clean,axis=1)
+        #normalize
+        data_clean=data_clean/np.max(np.abs(data_clean))
+
+        #framify arguments: seg_len, hop_size
+        hop_size=int(seg_len) #no overlap
+
+        num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
+        if num_frames==0:
+            data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
+            num_frames=1
+
+        data_not_finished=True
+        pointer=0
+        while(data_not_finished):
+            if i>=num_samples:
+                break
+            segment=data_clean[pointer:pointer+seg_len]
+            pointer=pointer+hop_size
+            if pointer+seg_len>len(data_clean):
+                data_not_finished=False
+            segment=segment.astype('float32')
+
+            snr=SNRs[i]
+            scale=scales[i]
+            #load noise signal
+            data_noise= next(noise_generator)
+            data_noise=np.mean(data_noise,axis=1)
+            #normalize
+            data_noise=data_noise/np.max(np.abs(data_noise))
+            new_noise=data_noise #if more processing is needed, add it here
+            #estimate clean signal power
+            power_clean=np.var(segment)
+            #estimate noise power
+            if prenoise:
+                power_noise=np.var(new_noise[fs::])
+            else:
+                power_noise=np.var(new_noise)
+
+            snr_linear = 10.0**(snr/10.0)
+
+            #sum both signals according to the target SNR
+            if prenoise:
+                segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
+            summed=segment+np.sqrt(power_clean/(snr_linear*power_noise))*new_noise #scale the noise so the mixture reaches the target SNR
+
+            summed=summed.astype('float32')
+
+            #apply a random level offset to both signals
+            summed=10.0**(scale/10.0)*summed
+            segment=10.0**(scale/10.0)*segment
+            segments_noisy.append(summed.astype('float32'))
+            segments_clean.append(segment.astype('float32'))
+            i=i+1
+
+    return segments_noisy, segments_clean
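The mixing rule used throughout this file scales the noise by sqrt(power_clean/(snr_linear*power_noise)) so that the mixture hits the requested SNR exactly. A quick, self-contained check of that algebra (all names here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
clean = rng.standard_normal(44100)
noise = rng.standard_normal(44100)

target_snr_db = 6.0
snr_linear = 10.0 ** (target_snr_db / 10.0)

# Same rule as in the loaders: var(clean) / var(scaled noise) == snr_linear.
k = np.sqrt(np.var(clean) / (snr_linear * np.var(noise)))
mixture = clean + k * noise

measured_db = 10.0 * np.log10(np.var(clean) / np.var(k * noise))
print(round(measured_db, 2))  # 6.0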
+
+def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5):
+
+    segments_clean=[]
+    segments_noisy=[]
+    seg_len=fs*seg_len_s
+    noises_info=os.path.join(path_noises,"info.csv")
+    SNRs=[2,6,12] #fixed test SNRs (hardcoded)
+    for path in path_music:
+        print(path)
+        train_samples=glob.glob(os.path.join(path,"*.wav"))
+        train_samples=sorted(train_samples)
+
+        noise_generator=__noise_sample_generator(noises_info, fs, seg_len, "test") #this will take care of everything
+        #load clean data files
+        for file in tqdm(train_samples):
+            data_clean, samplerate = sf.read(file)
+            if samplerate!=fs:
+                print("Wrong sample rate!")
+            #Stereo to mono
+            if len(data_clean.shape)>1:
+                data_clean=np.mean(data_clean,axis=1)
+            #normalize
+            data_clean=data_clean/np.max(np.abs(data_clean))
+
+            #framify arguments: seg_len, hop_size
+            hop_size=int(seg_len) #no overlap
+
+            num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
+            if num_frames==0:
+                data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
+                num_frames=1
+
+            pointer=0
+            segment=data_clean[pointer:pointer+(seg_len-2*fs)]
+            segment=segment.astype('float32')
+            segment=np.concatenate((np.zeros(shape=(2*fs,)), segment), axis=0) #prepend 2 s of silence
+
+            for snr in SNRs:
+                #load noise signal
+                data_noise= next(noise_generator)
+                data_noise=np.mean(data_noise,axis=1)
+                #normalize
+                data_noise=data_noise/np.max(np.abs(data_noise))
+                new_noise=data_noise #if more processing is needed, add it here
+                #estimate clean signal power
+                power_clean=np.var(segment)
+                #estimate noise power
+                power_noise=np.var(new_noise)
+
+                snr_linear = 10.0**(snr/10.0)
+
+                #scale the noise so the mixture reaches the target SNR, then sum
+                summed=segment+np.sqrt(power_clean/(snr_linear*power_noise))*new_noise
+
+                summed=summed.astype('float32')
+
+                segments_noisy.append(summed.astype('float32'))
+                segments_clean.append(segment.astype('float32'))
+
+    return segments_noisy, segments_clean
+
+def generate_val_data(path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5):
+
+    val_samples=[]
+    for path in path_music:
+        val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
+
+    #load clean data files
+    print("Loading clean files")
+    data_clean_loaded=[]
+    for ff in tqdm(range(0,len(val_samples))):
+        data_clean, samplerate = sf.read(val_samples[ff])
+        if samplerate!=fs:
+            print("Wrong sample rate!")
+        #Stereo to mono
+        if len(data_clean.shape)>1:
+            data_clean=np.mean(data_clean,axis=1)
+        #normalize
+        data_clean=data_clean/np.max(np.abs(data_clean))
+        data_clean_loaded.append(data_clean)
+        del data_clean
+
+    #framify clean data files
+    print("Framifying clean files")
+    seg_len=fs*seg_len_s
+    segments_clean=[]
+    for file in tqdm(data_clean_loaded):
+        #framify arguments: seg_len, hop_size
+        hop_size=int(seg_len) #no overlap
+
+        num_frames=np.floor(len(file)/hop_size - seg_len/hop_size +1)
+        pointer=0
+        for i in range(0,int(num_frames)):
+            segment=file[pointer:pointer+seg_len]
+            pointer=pointer+hop_size
+            segment=segment.astype('float32')
+            segments_clean.append(segment)
+
+    del data_clean_loaded
+
+    SNRs=np.random.uniform(2,20,len(segments_clean))
+    scales=np.random.uniform(-6,4,len(segments_clean))
+    noises_info=os.path.join(path_noises,"info.csv")
+
+    #load noise samples through the generator; each split (train, validation, test)
+    #has its own entries in the csv info file
+    noise_generator=__noise_sample_generator(noises_info, fs, seg_len, split)
+
+    #generate noisy segments
+    segments_noisy=[]
+    print("Processing noisy segments")
+
+    for i in tqdm(range(0,len(segments_clean))):
+        #load noise signal
+        data_noise= next(noise_generator)
+        #Stereo to mono
+        data_noise=np.mean(data_noise,axis=1)
+        #normalize
+        data_noise=data_noise/np.max(np.abs(data_noise))
+        new_noise=data_noise #if more processing is needed, add it here
+        #load clean data
+        data_clean=segments_clean[i]
+
+        #estimate clean signal power
+        power_clean=np.var(data_clean)
+        #estimate noise power
+        power_noise=np.var(new_noise)
+
+        snr_linear = 10.0**(SNRs[i]/10.0)
+
+        #scale the noise so the mixture reaches the target SNR, then sum
+        summed=data_clean+np.sqrt(power_clean/(snr_linear*power_noise))*new_noise
+
+        #apply a random level offset to both signals
+        summed=10.0**(scales[i]/10.0)*summed
+        segments_clean[i]=10.0**(scales[i]/10.0)*segments_clean[i]
+
+        segments_noisy.append(summed.astype('float32'))
+
+    return segments_noisy, segments_clean
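The framing arithmetic used above and in the other loaders counts how many non-overlapping seg_len windows fit in a file; with hop_size == seg_len it reduces to floor(len/seg_len), and a result of 0 triggers the zero-padding branch. A small check with illustrative values:

import numpy as np

fs, seg_len_s = 44100, 5
seg_len = fs * seg_len_s
hop_size = seg_len                       # no overlap, as in the loaders

for n in (3 * seg_len, 3 * seg_len + 1, seg_len - 1):
    num_frames = np.floor(n / hop_size - seg_len / hop_size + 1)
    print(n, int(num_frames))            # 3, 3, 0 -> the 0 case is zero-padded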
+
+def generator_train(path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False):
+
+    train_samples=[]
+    for path in path_music:
+        train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"),"*.wav")))
+
+    seg_len=fs*seg_len_s
+    noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
+    noise_generator=__noise_sample_generator(noises_info, fs, seg_len, split.decode("utf-8")) #this will take care of everything
+    #load clean data files
+    while True:
+        random.shuffle(train_samples)
+        for file in train_samples:
+            data, samplerate = sf.read(file)
+            if samplerate!=fs:
+                print("Wrong sample rate, resampling...")
+                data=np.transpose(data)
+                data=librosa.resample(data, samplerate, fs)
+                data=np.transpose(data)
+            data_clean=data
+            #Stereo to mono
+            if len(data.shape)>1:
+                data_clean=np.mean(data_clean,axis=1)
+
+            #normalize
+            data_clean=data_clean/np.max(np.abs(data_clean))
+
+            #framify arguments: seg_len, hop_size
+            hop_size=int(seg_len)
+
+            num_frames=np.floor(len(data_clean)/seg_len)
+            if num_frames==0:
+                data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
+                num_frames=1
+                pointer=0
+                data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if there is only one frame, roll it for augmentation
+            elif num_frames>1:
+                pointer=np.random.randint(0,hop_size) #random initial shift, great for augmentation: better than overlap, as we get different frames at each "while" iteration
+            else:
+                pointer=0
+
+            data_not_finished=True
+            while(data_not_finished):
+                segment=data_clean[pointer:pointer+seg_len]
+                pointer=pointer+hop_size
+                if pointer+seg_len>len(data_clean):
+                    data_not_finished=False
+                segment=segment.astype('float32')
+
+                snr_db=np.random.uniform(2,20)
+                scale=np.random.uniform(-6,4)
+
+                #load noise signal
+                data_noise= next(noise_generator)
+                data_noise=np.mean(data_noise,axis=1)
+                #normalize
+                data_noise=data_noise/np.max(np.abs(data_noise))
+                new_noise=data_noise #if more processing is needed, add it here
+                if stereo:
+                    #estimate clean signal power
+                    power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
+                    #estimate noise power
+                    power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
+                else:
+                    #estimate clean signal power
+                    power_clean=np.var(segment)
+                    #estimate noise power
+                    power_noise=np.var(new_noise)
+
+                snr_linear = 10.0**(snr_db/10.0)
+
+                #scale the noise so the mixture reaches the target SNR, then sum
+                summed=segment+np.sqrt(power_clean/(snr_linear*power_noise))*new_noise
+
+                #apply a random level offset to both signals
+                summed=10.0**(scale/10.0)*summed
+                segment=10.0**(scale/10.0)*segment
+
+                summed=summed.astype('float32')
+                yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
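generator_train is consumed through tf.data.Dataset.from_generator (see load_data below), which passes args to the generator as numpy byte strings; that is why path_music, path_noises and split are .decode("utf-8")-ed inside it. Calling it directly therefore requires byte strings too. A minimal sketch (the paths are placeholders):

gen = generator_train([b"/data/piano_train"], b"/data/noises", b"train",
                      fs=44100, seg_len_s=5)
noisy, clean = next(gen)          # one (noisy, clean) training pair
print(noisy.shape, clean.shape)   # (220500,) (220500,)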
+
+def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
+    print("Generating train dataset")
+    trainshape=int(fs*seg_len_s)
+
+    dataset_train = tf.data.Dataset.from_generator(generator_train, args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32))
+
+    print("Generating validation dataset")
+    segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s)
+
+    dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+
+    return dataset_train.shuffle(buffer_size), dataset_val
+
+def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs) -> tf.data.Dataset:
+    print("Generating test dataset")
+    segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, **kwargs)
+    dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+    return dataset_test
+
+def load_data_formal(path_pianos_test, path_noises, **kwargs) -> tf.data.Dataset:
+    print("Generating test dataset")
+    segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
+    print("number of segments:")
+    print(len(segments_noisy))
+    dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+    return dataset_test
+
+def load_real_test_recordings(buffer_size, path_recordings, **kwargs) -> tf.data.Dataset:
+    print("Generating real test dataset")
+
+    segments_noisy=generate_real_recordings_data(path_recordings, **kwargs)
+
+    dataset_test=tf.data.Dataset.from_tensor_slices(segments_noisy)
+    return dataset_test
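A sketch of how these entry points are meant to be wired into an input pipeline; the directory names and batch size are placeholders, not values from the patch:

dataset_train, dataset_val = load_data(
    buffer_size=1000,
    path_music_train=["/data/piano_train"],
    path_music_val=["/data/piano_val"],
    path_noises="/data/noises",
    fs=44100, seg_len_s=5)

# load_data already shuffles the (infinite) training stream; batch and prefetch here.
dataset_train = dataset_train.batch(8).prefetch(tf.data.AUTOTUNE)
dataset_val = dataset_val.batch(8)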
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..25d855c
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,164 @@
+import os
+import hydra
+import logging
+
+logger = logging.getLogger(__name__)
+
+def run(args):
+    import unet
+    import tensorflow as tf
+    import soundfile as sf
+    import numpy as np
+    from tqdm import tqdm
+    import librosa
+
+    path_experiment=str(args.path_experiment)
+
+    unet_model = unet.build_model_denoise(unet_args=args.unet)
+
+    ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)), path_experiment, 'checkpoint')
+    unet_model.load_weights(ckpt)
+
+    def do_stft(noisy):
+        window_fn = tf.signal.hamming_window
+
+        win_size=args.stft.win_size
+        hop_size=args.stft.hop_size
+
+        stft_signal_noisy=tf.signal.stft(noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
+        stft_noisy_stacked=tf.stack(values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
+
+        return stft_noisy_stacked
+
+    def do_istft(data):
+        window_fn = tf.signal.hamming_window
+
+        win_size=args.stft.win_size
+        hop_size=args.stft.hop_size
+
+        inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
+
+        pred_cpx=data[...,0] + 1j * data[...,1]
+        pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
+        return pred_time
+
+    audio=str(args.inference.audio)
+    data, samplerate = sf.read(audio)
+    if samplerate!=44100:
+        print("Resampling")
+        data=np.transpose(data)
+        data=librosa.resample(data, samplerate, 44100)
+        data=np.transpose(data)
+
+    #Stereo to mono
+    if len(data.shape)>1:
+        data=np.mean(data,axis=1)
+
+    segment_size=44100*20 #20 s segments
+
+    length_data=len(data)
+    overlapsize=2048 #samples (46 ms)
+    window=np.hanning(2*overlapsize)
+    window_right=window[overlapsize::]
+    window_left=window[0:overlapsize]
+    audio_finished=False
+    pointer=0
+    denoised_data=np.zeros(shape=(len(data),))
+    residual_noise=np.zeros(shape=(len(data),))
+    numchunks=int(np.ceil(length_data/segment_size))
+
+    for i in tqdm(range(numchunks)):
+        if pointer+segment_size
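do_stft/do_istft above rely on tf.signal.inverse_stft_window_fn to turn the Hamming analysis window into a matching synthesis window, so the forward/inverse pair reconstructs the signal. A quick round-trip check under the same settings (win_size and hop_size values are assumed for illustration; the patch reads them from args.stft):

import numpy as np
import tensorflow as tf

win_size, hop_size = 2048, 512            # placeholder values
window_fn = tf.signal.hamming_window

x = tf.random.normal([win_size * 8])
spec = tf.signal.stft(x, frame_length=win_size, frame_step=hop_size,
                      window_fn=window_fn, pad_end=True)
inv_fn = tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
y = tf.signal.inverse_stft(spec, win_size, hop_size, window_fn=inv_fn)

# Perfect reconstruction holds away from the signal edges.
err = np.max(np.abs(y.numpy()[win_size:4 * win_size] - x.numpy()[win_size:4 * win_size]))
print(err)  # ~1e-6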
+class DenseBlock(layers.Layer):
+    '''
+    [B, T, F, N] => [B, T, F, N]
+    DenseNet block consisting of "num_layers" densely connected convolutional layers.
+    '''
+    def __init__(self, num_layers, N, ksize, activation):
+        '''
+        num_layers: number of densely connected conv. layers
+        N: number of filters (same in each layer)
+        ksize: kernel size (same in each layer)
+        '''
+        super(DenseBlock, self).__init__()
+        self.activation=activation
+
+        self.paddings_1=get_paddings(ksize)
+        self.H=[]
+        self.num_layers=num_layers
+
+        for i in range(num_layers):
+            self.H.append(layers.Conv2D(filters=N,
+                                        kernel_size=ksize,
+                                        kernel_initializer=TruncatedNormal(),
+                                        strides=1,
+                                        padding='VALID',
+                                        activation=self.activation))
+
+    def call(self, x):
+        x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
+        x_ = self.H[0](x_)
+        if self.num_layers>1:
+            for h in self.H[1:]:
+                x = tf.concat([x_, x], axis=-1)
+                x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
+                x_ = h(x_)
+
+        return x_
+
+class FinalBlock(layers.Layer):
+    '''
+    [B, T, F, N] => [B, T, F, 2]
+    Final block: a 3x3 conv. layer that maps the output features to the output complex spectrogram.
+    '''
+    def __init__(self):
+        super(FinalBlock, self).__init__()
+        ksize=(3,3)
+        self.paddings_2=get_paddings(ksize)
+        self.conv2=layers.Conv2D(filters=2,
+                                 kernel_size=ksize,
+                                 kernel_initializer=TruncatedNormal(),
+                                 strides=1,
+                                 padding='VALID',
+                                 activation=None)
+
+    def call(self, inputs):
+        x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
+        pred=self.conv2(x)
+
+        return pred
+
+class SAM(layers.Layer):
+    '''
+    [B, T, F, N] => [B, T, F, N] , [B, T, F, 2]
+    Supervised Attention Module:
+    The purpose of SAM is to let the network propagate only the most relevant features to the second stage, discarding the less useful ones.
+    The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer.
+    The first-stage output is then calculated by adding the residual noise to the original input spectrogram.
+    The attention-guided features are computed using the attention masks M, which are calculated directly from the first-stage output with a 3x3 convolution and a sigmoid function.
+    '''
+    def __init__(self, n_feat):
+        super(SAM, self).__init__()
+
+        ksize=(3,3)
+        self.paddings_1=get_paddings(ksize)
+        self.conv1 = layers.Conv2D(filters=n_feat,
+                                   kernel_size=ksize,
+                                   kernel_initializer=TruncatedNormal(),
+                                   strides=1,
+                                   padding='VALID',
+                                   activation=None)
+        ksize=(3,3)
+        self.paddings_2=get_paddings(ksize)
+        self.conv2=layers.Conv2D(filters=2,
+                                 kernel_size=ksize,
+                                 kernel_initializer=TruncatedNormal(),
+                                 strides=1,
+                                 padding='VALID',
+                                 activation=None)
+
+        ksize=(3,3)
+        self.paddings_3=get_paddings(ksize)
+        self.conv3 = layers.Conv2D(filters=n_feat,
+                                   kernel_size=ksize,
+                                   kernel_initializer=TruncatedNormal(),
+                                   strides=1,
+                                   padding='VALID',
+                                   activation=None)
+        self.cropadd=CropAddBlock()
+
+    def call(self, inputs, input_spectrogram):
+        x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
+        x1 = self.conv1(x1)
+
+        x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
+        x=self.conv2(x)
+
+        #residual noise prediction
+        pred = layers.Add()([x, input_spectrogram]) #first-stage output
+
+        x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
+        M=self.conv3(x3)
+
+        M= tf.keras.activations.sigmoid(M)
+        x1=layers.Multiply()([x1, M])
+        x1 = layers.Add()([x1, inputs]) #features passed to the next stage
+
+        return x1, pred
+
+class AddFreqEncoding(layers.Layer):
+    '''
+    [B, T, F, 2] => [B, T, F, 12]
+    Generates frequency positional embeddings and concatenates them as 10 extra channels.
+    This function is optimized for F=1025.
+    '''
+    def __init__(self, f_dim):
+        super(AddFreqEncoding, self).__init__()
+        pi = tf.constant(m.pi)
+        pi=tf.cast(pi,'float32')
+        self.f_dim=f_dim #f_dim is fixed
+        n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
+        coss=tf.math.cos(pi*n)
+        f_channel = tf.expand_dims(coss, -1) #(1025,1)
+        self.fembeddings= f_channel
+
+        for k in range(1,10):
+            coss=tf.math.cos(2**k*pi*n)
+            f_channel = tf.expand_dims(coss, -1) #(1025,1)
+            self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
+
+    def call(self, input_tensor):
+        batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
+        time_dim = tf.shape(input_tensor)[1] # get time dimension
+
+        fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
+
+        return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
+
+def get_paddings(K):
+    return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2)],[K[1]//2, K[1]//2 -(1- K[1]%2)],[0,0]])
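get_paddings returns the [before, after] amounts that make a 'VALID' convolution after SYMMETRIC padding behave like a 'SAME' convolution; the network uses symmetric rather than zero padding at every conv. A quick shape check, assuming unet.py's get_paddings is in scope:

import tensorflow as tf
from tensorflow.keras import layers

ksize = (3, 3)
pads = get_paddings(ksize)                 # [[0,0],[1,1],[1,1],[0,0]]

x = tf.random.normal([1, 64, 1025, 2])
x_pad = tf.pad(x, pads, mode='SYMMETRIC')
y = layers.Conv2D(8, ksize, padding='VALID')(x_pad)
print(y.shape)                             # (1, 64, 1025, 8): spatial size preserved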
+
+class Decoder(layers.Layer):
+    '''
+    [B, T, F, N] , skip connections => [B, T, F, N]
+    Decoder side of the U-Net subnetwork.
+    '''
+    def __init__(self, Ns, Ss, unet_args):
+        super(Decoder, self).__init__()
+
+        self.Ns=Ns
+        self.Ss=Ss
+        self.activation=unet_args.activation
+        self.depth=unet_args.depth
+
+        ksize=(3,3)
+        self.paddings_3=get_paddings(ksize)
+        self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
+                                    kernel_size=ksize,
+                                    kernel_initializer=TruncatedNormal(),
+                                    strides=1,
+                                    padding='VALID',
+                                    activation=self.activation)
+
+        self.cropadd=CropAddBlock()
+
+        self.dblocks=[]
+        for i in range(self.depth):
+            self.dblocks.append(D_Block(layer_idx=i, N=self.Ns[i], S=self.Ss[i], activation=self.activation, num_tfc=unet_args.num_tfc))
+
+    def call(self, inputs, contracting_layers):
+        x=inputs
+        for i in range(self.depth,0,-1):
+            x=self.dblocks[i-1](x, contracting_layers[i-1])
+        return x
+
+class Encoder(tf.keras.Model):
+    '''
+    [B, T, F, N] => skip connections , [B, T, F, N_4]
+    Encoder side of the U-Net subnetwork.
+    '''
+    def __init__(self, Ns, Ss, unet_args):
+        super(Encoder, self).__init__()
+        self.Ns=Ns
+        self.Ss=Ss
+        self.activation=unet_args.activation
+        self.depth=unet_args.depth
+
+        self.contracting_layers = {}
+
+        self.eblocks=[]
+        for i in range(self.depth):
+            self.eblocks.append(E_Block(layer_idx=i, N0=self.Ns[i], N=self.Ns[i+1], S=self.Ss[i], activation=self.activation, num_tfc=unet_args.num_tfc))
+
+        self.i_block=I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)
+
+    def call(self, inputs):
+        x=inputs
+        for i in range(self.depth):
+            x, x_contract=self.eblocks[i](x)
+            self.contracting_layers[i] = x_contract
+        x=self.i_block(x)
+
+        return x, self.contracting_layers
+
+class MultiStage_denoise(tf.keras.Model):
+
+    def __init__(self, unet_args=None):
+        super(MultiStage_denoise, self).__init__()
+
+        self.activation=unet_args.activation
+        self.depth=unet_args.depth
+        if unet_args.use_fencoding:
+            self.freq_encoding=AddFreqEncoding(unet_args.f_dim)
+        self.use_sam=unet_args.use_SAM
+        self.use_fencoding=unet_args.use_fencoding
+        self.num_stages=unet_args.num_stages
+        #Encoder
+        self.Ns= [32,64,64,128,128,256,512]
+        self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)]
+
+        #initial feature extractor
+        ksize=(7,7)
+        self.paddings_1=get_paddings(ksize)
+        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
+                                      kernel_size=ksize,
+                                      kernel_initializer=TruncatedNormal(),
+                                      strides=1,
+                                      padding='VALID',
+                                      activation=self.activation)
+
+        self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
+        self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
+
+        self.cropconcat = CropConcatBlock()
+        self.cropadd = CropAddBlock()
+
+        self.finalblock=FinalBlock()
+
+        if self.num_stages>1:
+            self.sam_1=SAM(self.Ns[0])
+
+            #initial feature extractor of the second stage
+            ksize=(7,7)
+            self.paddings_2=get_paddings(ksize)
+            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
+                                          kernel_size=ksize,
+                                          kernel_initializer=TruncatedNormal(),
+                                          strides=1,
+                                          padding='VALID',
+                                          activation=self.activation)
+
+            self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
+            self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
+
+    @tf.function()
+    def call(self, inputs):
+
+        if self.use_fencoding:
+            x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
+        else:
+            x_w_freq=inputs
+
+        #initial feature extractor
+        x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
+        x=self.conv2d_1(x) #None, None, 1025, 32
+
+        x, contracting_layers_s1= self.encoder_s1(x)
+        #decoder
+        feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
+
+        if self.num_stages>1:
+            #SAM module
+            Fout, pred_stage_1=self.sam_1(feats_s1, inputs)
+
+            #initial feature extractor of the second stage
+            x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
+            x=self.conv2d_2(x)
+
+            if self.use_sam:
+                x = tf.concat([x, Fout], axis=-1)
+            else:
+                x = tf.concat([x, feats_s1], axis=-1)
+
+            x, contracting_layers_s2= self.encoder_s2(x)
+
+            feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
+
+            #consider implementing a third stage?
+
+            pred_stage_2=self.finalblock(feats_s2)
+            return pred_stage_2, pred_stage_1
+        else:
+            pred_stage_1=self.finalblock(feats_s1)
+            return pred_stage_1
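A sketch of instantiating the model directly; the unet_args fields are the attributes this class reads, with guessed values for illustration (a real run gets them from the hydra config, and inference.py goes through unet.build_model_denoise):

from types import SimpleNamespace
import tensorflow as tf

# Assumed values, not from the patch.
unet_args = SimpleNamespace(activation="relu", depth=6, use_fencoding=True,
                            f_dim=1025, use_SAM=True, num_stages=2, num_tfc=3)

model = MultiStage_denoise(unet_args=unet_args)
spec = tf.random.normal([1, 64, 1025, 2])    # [B, T, F, real/imag]
pred_s2, pred_s1 = model(spec)               # two-stage output
print(pred_s2.shape, pred_s1.shape)          # both (1, 64, 1025, 2)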
+
+class I_Block(layers.Layer):
+    '''
+    [B, T, F, N] => [B, T, F, N]
+    Intermediate block: a DenseNet block with a residual connection.
+    '''
+    def __init__(self, N, activation, num_tfc, **kwargs):
+        super(I_Block, self).__init__(**kwargs)
+
+        ksize=(3,3)
+        self.tfc=DenseBlock(num_tfc, N, ksize, activation)
+
+        self.conv2d_res= layers.Conv2D(filters=N,
+                                       kernel_size=(1,1),
+                                       kernel_initializer=TruncatedNormal(),
+                                       strides=1,
+                                       padding='VALID')
+
+    def call(self, inputs):
+        x=self.tfc(inputs)
+
+        inputs_proj=self.conv2d_res(inputs)
+        return layers.Add()([x, inputs_proj])
+
+class E_Block(layers.Layer):
+
+    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
+        super(E_Block, self).__init__(**kwargs)
+        self.layer_idx=layer_idx
+        self.N0=N0
+        self.N=N
+        self.S=S
+        self.activation=activation
+        self.i_block=I_Block(N0, activation, num_tfc)
+
+        ksize=(S[0]+2, S[1]+2)
+        self.paddings_2=get_paddings(ksize)
+        self.conv2d_2 = layers.Conv2D(filters=N,
+                                      kernel_size=ksize,
+                                      kernel_initializer=TruncatedNormal(),
+                                      strides=S,
+                                      padding='VALID',
+                                      activation=self.activation)
+
+    def call(self, inputs, training=None, **kwargs):
+        x=self.i_block(inputs)
+
+        x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
+        x_down = self.conv2d_2(x_down)
+
+        return x_down, x
+
+    def get_config(self):
+        return dict(layer_idx=self.layer_idx,
+                    N=self.N,
+                    S=self.S,
+                    **super(E_Block, self).get_config()
+                    )
+
+class D_Block(layers.Layer):
+
+    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
+        super(D_Block, self).__init__(**kwargs)
+        self.layer_idx=layer_idx
+        self.N=N
+        self.S=S
+        self.activation=activation
+        ksize=(S[0]+2, S[1]+2)
+        self.paddings_1=get_paddings(ksize)
+
+        self.tconv_1= layers.Conv2DTranspose(filters=N,
+                                             kernel_size=ksize,
+                                             kernel_initializer=TruncatedNormal(),
+                                             strides=S,
+                                             activation=self.activation,
+                                             padding='VALID')
+
+        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
+
+        self.projection = layers.Conv2D(filters=N,
+                                        kernel_size=(1,1),
+                                        kernel_initializer=TruncatedNormal(),
+                                        strides=1,
+                                        activation=self.activation,
+                                        padding='VALID')
+        self.cropadd=CropAddBlock()
+        self.cropconcat=CropConcatBlock()
+
+        self.i_block=I_Block(N, activation, num_tfc)
+
+    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
+        x = inputs
+        x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
+        x = self.tconv_1(x)
+
+        x2= self.upsampling(inputs)
+
+        if x2.shape[-1]!=x.shape[-1]:
+            x2= self.projection(x2)
+
+        x= self.cropadd(x, x2)
+
+        x=self.cropconcat(x, bridge)
+
+        x=self.i_block(x)
+        return x
+
+    def get_config(self):
+        return dict(layer_idx=self.layer_idx,
+                    N=self.N,
+                    S=self.S,
+                    **super(D_Block, self).get_config()
+                    )
+
+class CropAddBlock(layers.Layer):
+
+    def call(self, down_layer, x, **kwargs):
+        x1_shape = tf.shape(down_layer)
+        x2_shape = tf.shape(x)
+
+        height_diff = (x1_shape[1] - x2_shape[1]) // 2
+        width_diff = (x1_shape[2] - x2_shape[2]) // 2
+
+        down_layer_cropped = down_layer[:,
+                                        height_diff: (x2_shape[1] + height_diff),
+                                        width_diff: (x2_shape[2] + width_diff),
+                                        :]
+
+        x = layers.Add()([down_layer_cropped, x])
+        return x
+
+class CropConcatBlock(layers.Layer):
+
+    def call(self, down_layer, x, **kwargs):
+        x1_shape = tf.shape(down_layer)
+        x2_shape = tf.shape(x)
+
+        height_diff = (x1_shape[1] - x2_shape[1]) // 2
+        width_diff = (x1_shape[2] - x2_shape[2]) // 2
+
+        down_layer_cropped = down_layer[:,
+                                        height_diff: (x2_shape[1] + height_diff),
+                                        width_diff: (x2_shape[2] + width_diff),
+                                        :]
+
+        x = tf.concat([down_layer_cropped, x], axis=-1)
+        return x
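CropAddBlock and CropConcatBlock center-crop the first feature map to the second one's (possibly smaller) spatial size before adding or concatenating; this is what absorbs the off-by-one sizes produced by strided convolutions on odd dimensions such as F=1025. A small demonstration, assuming the classes above are in scope:

import tensorflow as tf

down = tf.random.normal([1, 66, 1027, 16])  # encoder skip, slightly larger
up = tf.random.normal([1, 64, 1025, 16])    # decoder features

out = CropConcatBlock()(down, up)
print(out.shape)                             # (1, 64, 1025, 32)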