From 98702405729670a2ef501bda5945447646938451 Mon Sep 17 00:00:00 2001 From: festinuz Date: Thu, 8 May 2025 12:56:32 +0300 Subject: [PATCH] fix formatting & linters --- .gitignore | 3 + dataset_loader.py | 800 +++++++++++++++++++++++++--------------------- inference.py | 260 +++++++++------ train.py | 203 +++++++----- trainer.py | 83 +++-- unet.py | 698 ++++++++++++++++++++++------------------ 6 files changed, 1143 insertions(+), 904 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..da90d1a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +experiments +outputs +__pycache__ \ No newline at end of file diff --git a/dataset_loader.py b/dataset_loader.py index 1fae225..582af6c 100644 --- a/dataset_loader.py +++ b/dataset_loader.py @@ -4,485 +4,561 @@ import tensorflow as tf import random import os import numpy as np -from scipy.fft import fft, ifft import soundfile as sf -import math import pandas as pd -import scipy as sp import glob from tqdm import tqdm -#generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis. -def __noise_sample_generator(info_file,fs, length_seq, split): - head=os.path.split(info_file)[0] - load_data=pd.read_csv(info_file) - #split= train, validation, test - load_data_split=load_data.loc[load_data["split"]==split] - load_data_split=load_data_split.reset_index(drop=True) + +# generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis. +def __noise_sample_generator(info_file, fs, length_seq, split): + head = os.path.split(info_file)[0] + load_data = pd.read_csv(info_file) + # split= train, validation, test + load_data_split = load_data.loc[load_data["split"] == split] + load_data_split = load_data_split.reset_index(drop=True) while True: r = list(range(len(load_data_split))) - if split!="test": + if split != "test": random.shuffle(r) for i in r: - segments=ast.literal_eval(load_data_split.loc[i,"segments"]) - if split=="test": - loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i])) + segments = ast.literal_eval(load_data_split.loc[i, "segments"]) + if split == "test": + loaded_data, Fs = sf.read( + os.path.join( + head, + load_data_split["recording"].loc[i], + load_data_split["largest_segment"].loc[i], + ) + ) else: - num=np.random.randint(0,len(segments)) - loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num])) - assert(fs==Fs, "wrong sampling rate") + num = np.random.randint(0, len(segments)) + loaded_data, Fs = sf.read( + os.path.join( + head, load_data_split["recording"].loc[i], segments[num] + ) + ) + assert fs == Fs, "wrong sampling rate" - yield __extend_sample_by_repeating(loaded_data,fs,length_seq) + yield __extend_sample_by_repeating(loaded_data, fs, length_seq) -def __extend_sample_by_repeating(data, fs,seq_len): - rpm=78 - target_samp=seq_len - large_data=np.zeros(shape=(target_samp,2)) - - if len(data)>=target_samp: - large_data=data[0:target_samp] + +def __extend_sample_by_repeating(data, fs, seq_len): + rpm = 78 + target_samp = seq_len + large_data = np.zeros(shape=(target_samp, 2)) + + if len(data) >= target_samp: + large_data = data[0:target_samp] return large_data - - bls=(1000*44100)/1000 #hardcoded - - window=np.stack((np.hanning(bls) ,np.hanning(bls)), axis=1) - window_left=window[0:int(bls/2),:] - window_right=window[int(bls/2)::,:] - bls=int(bls/2) - - rps=rpm/60 - period=1/rps - - period_sam=int(period*fs) - - overhead=len(data)%period_sam - - if(overhead>bls): - complete_periods=(len(data)//period_sam)*period_sam + + bls = (1000 * 44100) / 1000 # hardcoded + + window = np.stack((np.hanning(bls), np.hanning(bls)), axis=1) + window_left = window[0 : int(bls / 2), :] + window_right = window[int(bls / 2) : :, :] + bls = int(bls / 2) + + rps = rpm / 60 + period = 1 / rps + + period_sam = int(period * fs) + + overhead = len(data) % period_sam + + if overhead > bls: + complete_periods = (len(data) // period_sam) * period_sam else: - complete_periods=(len(data)//period_sam -1)*period_sam - - - a=np.multiply(data[0:bls], window_left) - b=np.multiply(data[complete_periods:complete_periods+bls], window_right) - c_1=np.concatenate((data[0:complete_periods,:],b)) - c_2=np.concatenate((a,data[bls:complete_periods,:],b)) - c_3=np.concatenate((a,data[bls::,:])) - - large_data[0:complete_periods+bls,:]=c_1 - - - pointer=complete_periods - not_finished=True - while (not_finished): - if target_samp>pointer+complete_periods+bls: - large_data[pointer:pointer+complete_periods+bls] +=c_2 - pointer+=complete_periods - else: - large_data[pointer::]+=c_3[0:(target_samp-pointer)] - #finish - not_finished=False + complete_periods = (len(data) // period_sam - 1) * period_sam + + a = np.multiply(data[0:bls], window_left) + b = np.multiply(data[complete_periods : complete_periods + bls], window_right) + c_1 = np.concatenate((data[0:complete_periods, :], b)) + c_2 = np.concatenate((a, data[bls:complete_periods, :], b)) + c_3 = np.concatenate((a, data[bls::, :])) + + large_data[0 : complete_periods + bls, :] = c_1 + + pointer = complete_periods + not_finished = True + while not_finished: + if target_samp > pointer + complete_periods + bls: + large_data[pointer : pointer + complete_periods + bls] += c_2 + pointer += complete_periods + else: + large_data[pointer::] += c_3[0 : (target_samp - pointer)] + # finish + not_finished = False return large_data - -def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False): - records_info=os.path.join(path_recordings,"audio_files.txt") +def generate_real_recordings_data( + path_recordings, fs=44100, seg_len_s=15, stereo=False +): + records_info = os.path.join(path_recordings, "audio_files.txt") num_lines = sum(1 for line in open(records_info)) - f = open(records_info,"r") - #load data record files + f = open(records_info, "r") + # load data record files print("Loading record files") - records=[] - seg_len=fs*seg_len_s - pointer=int(fs*5) #starting at second 5 by default + records = [] + seg_len = fs * seg_len_s + pointer = int(fs * 5) # starting at second 5 by default for i in tqdm(range(num_lines)): - audio=f.readline() - audio=audio[:-1] - data, fs=sf.read(os.path.join(path_recordings,audio)) - if len(data.shape)>1 and not(stereo): - data=np.mean(data,axis=1) - #elif stereo and len(data.shape)==1: + audio = f.readline() + audio = audio[:-1] + data, fs = sf.read(os.path.join(path_recordings, audio)) + if len(data.shape) > 1 and not (stereo): + data = np.mean(data, axis=1) + # elif stereo and len(data.shape)==1: # data=np.stack((data, data), axis=1) - #normalize - data=data/np.max(np.abs(data)) - segment=data[pointer:pointer+seg_len] + # normalize + data = data / np.max(np.abs(data)) + segment = data[pointer : pointer + seg_len] records.append(segment.astype("float32")) return records -def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False): +def generate_paired_data_test_formal( + path_pianos, + path_noises, + noise_amount="low_snr", + num_samples=-1, + fs=44100, + seg_len_s=5, + extend=True, + stereo=False, + prenoise=False, +): print(num_samples) - segments_clean=[] - segments_noisy=[] - seg_len=fs*seg_len_s - noises_info=os.path.join(path_noises,"info.csv") + segments_clean = [] + segments_noisy = [] + seg_len = fs * seg_len_s + noises_info = os.path.join(path_noises, "info.csv") np.random.seed(42) - if noise_amount=="low_snr": - SNRs=np.random.uniform(2,6,num_samples) - elif noise_amount=="mid_snr": - SNRs=np.random.uniform(6,12,num_samples) + if noise_amount == "low_snr": + SNRs = np.random.uniform(2, 6, num_samples) + elif noise_amount == "mid_snr": + SNRs = np.random.uniform(6, 12, num_samples) - scales=np.random.uniform(-4,0,num_samples) - #SNRs=[2,6,12] #HARDCODED!!!! - i=0 + scales = np.random.uniform(-4, 0, num_samples) + # SNRs=[2,6,12] #HARDCODED!!!! + i = 0 print(path_pianos[0]) print(seg_len) - train_samples=glob.glob(os.path.join(path_pianos[0],"*.wav")) - train_samples=sorted(train_samples) + train_samples = glob.glob(os.path.join(path_pianos[0], "*.wav")) + train_samples = sorted(train_samples) if prenoise: - noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, extend, "test") #Adds 1s of silence add the begiing, longer noise + noise_generator = __noise_sample_generator( + noises_info, fs, seg_len + fs, extend, "test" + ) # Adds 1s of silence add the begiing, longer noise else: - noise_generator=__noise_sample_generator(noises_info,fs, seg_len, extend, "test") #this will take care of everything - #load data clean files - for file in tqdm(train_samples): #add [1:5] for testing + noise_generator = __noise_sample_generator( + noises_info, fs, seg_len, extend, "test" + ) # this will take care of everything + # load data clean files + for file in tqdm(train_samples): # add [1:5] for testing data_clean, samplerate = sf.read(file) - if samplerate!=fs: + if samplerate != fs: print("!!!!WRONG SAMPLE RATe!!!") - #Stereo to mono - if len(data_clean.shape)>1 and not(stereo): - data_clean=np.mean(data_clean,axis=1) - #elif stereo and len(data_clean.shape)==1: + # Stereo to mono + if len(data_clean.shape) > 1 and not (stereo): + data_clean = np.mean(data_clean, axis=1) + # elif stereo and len(data_clean.shape)==1: # data_clean=np.stack((data_clean, data_clean), axis=1) - #normalize - data_clean=data_clean/np.max(np.abs(data_clean)) - #data_clean_loaded.append(data_clean) - - #framify data clean files - - #framify arguments: seg_len, hop_size - hop_size=int(seg_len)# no overlap - - num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1) + # normalize + data_clean = data_clean / np.max(np.abs(data_clean)) + # data_clean_loaded.append(data_clean) + + # framify data clean files + + # framify arguments: seg_len, hop_size + hop_size = int(seg_len) # no overlap + + num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1) print(num_frames) - if num_frames==0: - data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0) - num_frames=1 + if num_frames == 0: + data_clean = np.concatenate( + (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))), + axis=0, + ) + num_frames = 1 - data_not_finished=True - pointer=0 - while(data_not_finished): - if i>=num_samples: + data_not_finished = True + pointer = 0 + while data_not_finished: + if i >= num_samples: break - segment=data_clean[pointer:pointer+seg_len] - pointer=pointer+hop_size - if pointer+seg_len>len(data_clean): - data_not_finished=False - segment=segment.astype('float32') - - #SNRs=np.random.uniform(2,20) - snr=SNRs[i] - scale=scales[i] - #load noise signal - data_noise= next(noise_generator) - data_noise=np.mean(data_noise,axis=1) - #normalize - data_noise=data_noise/np.max(np.abs(data_noise)) - new_noise=data_noise #if more processing needed, add here - #load clean data - #configure sizes - power_clean=np.var(segment) - #estimate noise power + segment = data_clean[pointer : pointer + seg_len] + pointer = pointer + hop_size + if pointer + seg_len > len(data_clean): + data_not_finished = False + segment = segment.astype("float32") + + # SNRs=np.random.uniform(2,20) + snr = SNRs[i] + scale = scales[i] + # load noise signal + data_noise = next(noise_generator) + data_noise = np.mean(data_noise, axis=1) + # normalize + data_noise = data_noise / np.max(np.abs(data_noise)) + new_noise = data_noise # if more processing needed, add here + # load clean data + # configure sizes + power_clean = np.var(segment) + # estimate noise power if prenoise: - power_noise=np.var(new_noise[fs::]) + power_noise = np.var(new_noise[fs::]) else: - power_noise=np.var(new_noise) + power_noise = np.var(new_noise) - snr = 10.0**(snr/10.0) + snr = 10.0 ** (snr / 10.0) - #sum both signals according to snr + # sum both signals according to snr if prenoise: - segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence - summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!! + segment = np.concatenate( + (np.zeros(shape=(fs,)), segment), axis=0 + ) # add one second of silence + summed = ( + segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise + ) # not sure if this is correct, maybe revisit later!! - summed=summed.astype('float32') - #yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment) - - - summed=10.0**(scale/10.0) *summed - segment=10.0**(scale/10.0) *segment - segments_noisy.append(summed.astype('float32')) - segments_clean.append(segment.astype('float32')) - i=i+1 + summed = summed.astype("float32") + # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment) + + summed = 10.0 ** (scale / 10.0) * summed + segment = 10.0 ** (scale / 10.0) * segment + segments_noisy.append(summed.astype("float32")) + segments_clean.append(segment.astype("float32")) + i = i + 1 return segments_noisy, segments_clean -def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5): - segments_clean=[] - segments_noisy=[] - seg_len=fs*seg_len_s - noises_info=os.path.join(path_noises,"info.csv") - SNRs=[2,6,12] #HARDCODED!!!! +def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5): + segments_clean = [] + segments_noisy = [] + seg_len = fs * seg_len_s + noises_info = os.path.join(path_noises, "info.csv") + SNRs = [2, 6, 12] # HARDCODED!!!! for path in path_music: print(path) - train_samples=glob.glob(os.path.join(path,"*.wav")) - train_samples=sorted(train_samples) + train_samples = glob.glob(os.path.join(path, "*.wav")) + train_samples = sorted(train_samples) - noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything - #load data clean files - jj=0 - for file in tqdm(train_samples): #add [1:5] for testing + noise_generator = __noise_sample_generator( + noises_info, fs, seg_len, "test" + ) # this will take care of everything + # load data clean files + for file in tqdm(train_samples): # add [1:5] for testing data_clean, samplerate = sf.read(file) - if samplerate!=fs: + if samplerate != fs: print("!!!!WRONG SAMPLE RATe!!!") - #Stereo to mono - if len(data_clean.shape)>1: - data_clean=np.mean(data_clean,axis=1) - #normalize - data_clean=data_clean/np.max(np.abs(data_clean)) - #data_clean_loaded.append(data_clean) - - #framify data clean files - - #framify arguments: seg_len, hop_size - hop_size=int(seg_len)# no overlap - - num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1) - if num_frames==0: - data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0) - num_frames=1 + # Stereo to mono + if len(data_clean.shape) > 1: + data_clean = np.mean(data_clean, axis=1) + # normalize + data_clean = data_clean / np.max(np.abs(data_clean)) + # data_clean_loaded.append(data_clean) + + # framify data clean files + + # framify arguments: seg_len, hop_size + hop_size = int(seg_len) # no overlap + + num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1) + if num_frames == 0: + data_clean = np.concatenate( + (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))), + axis=0, + ) + num_frames = 1 + + pointer = 0 + segment = data_clean[pointer : pointer + (seg_len - 2 * fs)] + segment = segment.astype("float32") + segment = np.concatenate( + (np.zeros(shape=(2 * fs,)), segment), axis=0 + ) # I hope its ok + # segments_clean.append(segment) - pointer=0 - segment=data_clean[pointer:pointer+(seg_len-2*fs)] - segment=segment.astype('float32') - segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #I hope its ok - #segments_clean.append(segment) - for snr in SNRs: - #load noise signal - data_noise= next(noise_generator) - data_noise=np.mean(data_noise,axis=1) - #normalize - data_noise=data_noise/np.max(np.abs(data_noise)) - new_noise=data_noise #if more processing needed, add here - #load clean data - #configure sizes - #estimate clean signal power - power_clean=np.var(segment) - #estimate noise power - power_noise=np.var(new_noise) + # load noise signal + data_noise = next(noise_generator) + data_noise = np.mean(data_noise, axis=1) + # normalize + data_noise = data_noise / np.max(np.abs(data_noise)) + new_noise = data_noise # if more processing needed, add here + # load clean data + # configure sizes + # estimate clean signal power + power_clean = np.var(segment) + # estimate noise power + power_noise = np.var(new_noise) - snr = 10.0**(snr/10.0) + snr = 10.0 ** (snr / 10.0) - #sum both signals according to snr - summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!! - summed=summed.astype('float32') - #yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment) - - segments_noisy.append(summed.astype('float32')) - segments_clean.append(segment.astype('float32')) + # sum both signals according to snr + summed = ( + segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise + ) # not sure if this is correct, maybe revisit later!! + summed = summed.astype("float32") + # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment) + + segments_noisy.append(summed.astype("float32")) + segments_clean.append(segment.astype("float32")) return segments_noisy, segments_clean -def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5): - val_samples=[] +def generate_val_data( + path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5 +): + val_samples = [] for path in path_music: - val_samples.extend(glob.glob(os.path.join(path,"*.wav"))) + val_samples.extend(glob.glob(os.path.join(path, "*.wav"))) - #load data clean files + # load data clean files print("Loading clean files") - data_clean_loaded=[] - for ff in tqdm(range(0,len(val_samples))): #add [1:5] for testing + data_clean_loaded = [] + for ff in tqdm(range(0, len(val_samples))): # add [1:5] for testing data_clean, samplerate = sf.read(val_samples[ff]) - if samplerate!=fs: + if samplerate != fs: print("!!!!WRONG SAMPLE RATe!!!") - #Stereo to mono - if len(data_clean.shape)>1 : - data_clean=np.mean(data_clean,axis=1) - #normalize - data_clean=data_clean/np.max(np.abs(data_clean)) + # Stereo to mono + if len(data_clean.shape) > 1: + data_clean = np.mean(data_clean, axis=1) + # normalize + data_clean = data_clean / np.max(np.abs(data_clean)) data_clean_loaded.append(data_clean) del data_clean - #framify data clean files + # framify data clean files print("Framifying clean files") - seg_len=fs*seg_len_s - segments_clean=[] + seg_len = fs * seg_len_s + segments_clean = [] for file in tqdm(data_clean_loaded): + # framify arguments: seg_len, hop_size + hop_size = int(seg_len) # no overlap - #framify arguments: seg_len, hop_size - hop_size=int(seg_len)# no overlap - - num_frames=np.floor(len(file)/hop_size - seg_len/hop_size +1) - pointer=0 - for i in range(0,int(num_frames)): - segment=file[pointer:pointer+seg_len] - pointer=pointer+hop_size - segment=segment.astype('float32') + num_frames = np.floor(len(file) / hop_size - seg_len / hop_size + 1) + pointer = 0 + for i in range(0, int(num_frames)): + segment = file[pointer : pointer + seg_len] + pointer = pointer + hop_size + segment = segment.astype("float32") segments_clean.append(segment) del data_clean_loaded - - SNRs=np.random.uniform(2,20,len(segments_clean)) - scales=np.random.uniform(-6,4,len(segments_clean)) - #noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean)) - noises_info=os.path.join(path_noises,"info.csv") - noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything - + SNRs = np.random.uniform(2, 20, len(segments_clean)) + scales = np.random.uniform(-6, 4, len(segments_clean)) + # noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean)) + noises_info = os.path.join(path_noises, "info.csv") - #generate noisy segments - #load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file + noise_generator = __noise_sample_generator( + noises_info, fs, seg_len, split + ) # this will take care of everything - #noise_samples=glob.glob(os.path.join(path_noises,"*.wav")) - segments_noisy=[] + # generate noisy segments + # load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file + + # noise_samples=glob.glob(os.path.join(path_noises,"*.wav")) + segments_noisy = [] print("Processing noisy segments") - for i in tqdm(range(0,len(segments_clean))): - #load noise signal - data_noise= next(noise_generator) - #Stereo to mono - data_noise=np.mean(data_noise,axis=1) - #normalize - data_noise=data_noise/np.max(np.abs(data_noise)) - new_noise=data_noise #if more processing needed, add here - #load clean data - data_clean=segments_clean[i] - #configure sizes - - - #estimate clean signal power - power_clean=np.var(data_clean) - #estimate noise power - power_noise=np.var(new_noise) + for i in tqdm(range(0, len(segments_clean))): + # load noise signal + data_noise = next(noise_generator) + # Stereo to mono + data_noise = np.mean(data_noise, axis=1) + # normalize + data_noise = data_noise / np.max(np.abs(data_noise)) + new_noise = data_noise # if more processing needed, add here + # load clean data + data_clean = segments_clean[i] + # configure sizes - snr = 10.0**(SNRs[i]/10.0) + # estimate clean signal power + power_clean = np.var(data_clean) + # estimate noise power + power_noise = np.var(new_noise) - #sum both signals according to snr - summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!! - #the rest is normal - - summed=10.0**(scales[i]/10.0) *summed - segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i] + snr = 10.0 ** (SNRs[i] / 10.0) + + # sum both signals according to snr + summed = ( + data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise + ) # not sure if this is correct, maybe revisit later!! + # the rest is normal + + summed = 10.0 ** (scales[i] / 10.0) * summed + segments_clean[i] = 10.0 ** (scales[i] / 10.0) * segments_clean[i] + + segments_noisy.append(summed.astype("float32")) - segments_noisy.append(summed.astype('float32')) - return segments_noisy, segments_clean - -def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False): - - train_samples=[] +def generator_train( + path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False +): + train_samples = [] for path in path_music: - train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav"))) + train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"), "*.wav"))) - seg_len=fs*seg_len_s - noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv") - noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything - #load data clean files + seg_len = fs * seg_len_s + noises_info = os.path.join(path_noises.decode("utf-8"), "info.csv") + noise_generator = __noise_sample_generator( + noises_info, fs, seg_len, split.decode("utf-8") + ) # this will take care of everything + # load data clean files while True: random.shuffle(train_samples) - for file in train_samples: + for file in train_samples: data, samplerate = sf.read(file) - assert(samplerate==fs, "wrong sampling rate") - data_clean=data - #Stereo to mono - if len(data.shape)>1 : - data_clean=np.mean(data_clean,axis=1) + assert samplerate == fs, "wrong sampling rate" + data_clean = data + # Stereo to mono + if len(data.shape) > 1: + data_clean = np.mean(data_clean, axis=1) - #normalize - data_clean=data_clean/np.max(np.abs(data_clean)) - - #framify data clean files - - #framify arguments: seg_len, hop_size - hop_size=int(seg_len) - - num_frames=np.floor(len(data_clean)/seg_len) - if num_frames==0: - data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0) - num_frames=1 - pointer=0 - data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation - elif num_frames>1: - pointer=np.random.randint(0,hop_size) #initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration + # normalize + data_clean = data_clean / np.max(np.abs(data_clean)) + + # framify data clean files + + # framify arguments: seg_len, hop_size + hop_size = int(seg_len) + + num_frames = np.floor(len(data_clean) / seg_len) + if num_frames == 0: + data_clean = np.concatenate( + (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))), + axis=0, + ) + num_frames = 1 + pointer = 0 + data_clean = np.roll( + data_clean, np.random.randint(0, seg_len) + ) # if only one frame, roll it for augmentation + elif num_frames > 1: + pointer = np.random.randint( + 0, hop_size + ) # initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration else: - pointer=0 + pointer = 0 - data_not_finished=True - while(data_not_finished): - segment=data_clean[pointer:pointer+seg_len] - pointer=pointer+hop_size - if pointer+seg_len>len(data_clean): - data_not_finished=False - segment=segment.astype('float32') - - SNRs=np.random.uniform(2,20) - scale=np.random.uniform(-6,4) - - - #load noise signal - data_noise= next(noise_generator) - data_noise=np.mean(data_noise,axis=1) - #normalize - data_noise=data_noise/np.max(np.abs(data_noise)) - new_noise=data_noise #if more processing needed, add here - #load clean data - #configure sizes + data_not_finished = True + while data_not_finished: + segment = data_clean[pointer : pointer + seg_len] + pointer = pointer + hop_size + if pointer + seg_len > len(data_clean): + data_not_finished = False + segment = segment.astype("float32") + + SNRs = np.random.uniform(2, 20) + scale = np.random.uniform(-6, 4) + + # load noise signal + data_noise = next(noise_generator) + data_noise = np.mean(data_noise, axis=1) + # normalize + data_noise = data_noise / np.max(np.abs(data_noise)) + new_noise = data_noise # if more processing needed, add here + # load clean data + # configure sizes if stereo: - #estimate clean signal power - power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1]) - #estimate noise power - power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1]) + # estimate clean signal power + power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var( + segment[:, 1] + ) + # estimate noise power + power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var( + new_noise[:, 1] + ) else: - #estimate clean signal power - power_clean=np.var(segment) - #estimate noise power - power_noise=np.var(new_noise) + # estimate clean signal power + power_clean = np.var(segment) + # estimate noise power + power_noise = np.var(new_noise) - snr = 10.0**(SNRs/10.0) + snr = 10.0 ** (SNRs / 10.0) - - #sum both signals according to snr - summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!! - summed=10.0**(scale/10.0) *summed - segment=10.0**(scale/10.0) *segment - - summed=summed.astype('float32') + # sum both signals according to snr + summed = ( + segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise + ) # not sure if this is correct, maybe revisit later!! + summed = 10.0 ** (scale / 10.0) * summed + segment = 10.0 ** (scale / 10.0) * segment + + summed = summed.astype("float32") yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment) - -def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) : + + +def load_data( + buffer_size, + path_music_train, + path_music_val, + path_noises, + fs=44100, + seg_len_s=5, + extend=True, + stereo=False, +): print("Generating train dataset") - trainshape=int(fs*seg_len_s) - - dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) ) + trainshape = int(fs * seg_len_s) + dataset_train = tf.data.Dataset.from_generator( + generator_train, + args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo), + output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))), + output_types=(tf.float32, tf.float32), + ) print("Generating validation dataset") - segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s) - - dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) + segments_noisy, segments_clean = generate_val_data( + path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s + ) + + dataset_val = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) return dataset_train.shuffle(buffer_size), dataset_val -def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs): + +def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs): print("Generating test dataset") - segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs) - dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) - #dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3])) - #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) + segments_noisy, segments_clean = generate_test_data( + path_pianos_test, path_noises, extend=True, **kwargs + ) + dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) + # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3])) + # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) return dataset_test -def load_data_formal( path_pianos_test, path_noises, **kwargs) : + + +def load_data_formal(path_pianos_test, path_noises, **kwargs): print("Generating test dataset") - segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs) + segments_noisy, segments_clean = generate_paired_data_test_formal( + path_pianos_test, path_noises, extend=True, **kwargs + ) print("segments::") print(len(segments_noisy)) - dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) - #dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3])) - #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) + dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean)) + # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3])) + # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) return dataset_test -def load_real_test_recordings(buffer_size, path_recordings, **kwargs): + +def load_real_test_recordings(buffer_size, path_recordings, **kwargs): print("Generating real test dataset") - - segments_noisy=generate_real_recordings_data(path_recordings, **kwargs) - dataset_test=tf.data.Dataset.from_tensor_slices(segments_noisy) - #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) + segments_noisy = generate_real_recordings_data(path_recordings, **kwargs) + + dataset_test = tf.data.Dataset.from_tensor_slices(segments_noisy) + # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples) return dataset_test diff --git a/inference.py b/inference.py index 547dc2e..ff42f14 100644 --- a/inference.py +++ b/inference.py @@ -4,6 +4,7 @@ import logging logger = logging.getLogger(__name__) + def run(args): import unet import tensorflow as tf @@ -11,130 +12,206 @@ def run(args): import numpy as np from tqdm import tqdm import scipy.signal - - path_experiment=str(args.path_experiment) + + path_experiment = str(args.path_experiment) unet_model = unet.build_model_denoise(unet_args=args.unet) - ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint') + ckpt = os.path.join( + os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint" + ) unet_model.load_weights(ckpt) def do_stft(noisy): - window_fn = tf.signal.hamming_window - win_size=args.stft.win_size - hop_size=args.stft.hop_size + win_size = args.stft.win_size + hop_size = args.stft.hop_size + + stft_signal_noisy = tf.signal.stft( + noisy, + frame_length=win_size, + window_fn=window_fn, + frame_step=hop_size, + pad_end=True, + ) + stft_noisy_stacked = tf.stack( + values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], + axis=-1, + ) - - stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True) - stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1) - return stft_noisy_stacked def do_istft(data): - window_fn = tf.signal.hamming_window - win_size=args.stft.win_size - hop_size=args.stft.hop_size + win_size = args.stft.win_size + hop_size = args.stft.hop_size - inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn) + inv_window_fn = tf.signal.inverse_stft_window_fn( + hop_size, forward_window_fn=window_fn + ) - pred_cpx=data[...,0] + 1j * data[...,1] - pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn) + pred_cpx = data[..., 0] + 1j * data[..., 1] + pred_time = tf.signal.inverse_stft( + pred_cpx, win_size, hop_size, window_fn=inv_window_fn + ) return pred_time - audio=str(args.inference.audio) + audio = str(args.inference.audio) data, samplerate = sf.read(audio) print(data.dtype) - #Stereo to mono - if len(data.shape)>1: - data=np.mean(data,axis=1) - - if samplerate!=44100: + # Stereo to mono + if len(data.shape) > 1: + data = np.mean(data, axis=1) + + if samplerate != 44100: print("Resampling") - - data=scipy.signal.resample(data, int((44100 / samplerate )*len(data))+1) - - - - segment_size=44100*5 #20s segments - length_data=len(data) - overlapsize=2048 #samples (46 ms) - window=np.hanning(2*overlapsize) - window_right=window[overlapsize::] - window_left=window[0:overlapsize] - audio_finished=False - pointer=0 - denoised_data=np.zeros(shape=(len(data),)) - residual_noise=np.zeros(shape=(len(data),)) - numchunks=int(np.ceil(length_data/segment_size)) - + data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1) + + segment_size = 44100 * 5 # 20s segments + + length_data = len(data) + overlapsize = 2048 # samples (46 ms) + window = np.hanning(2 * overlapsize) + window_right = window[overlapsize::] + window_left = window[0:overlapsize] + pointer = 0 + denoised_data = np.zeros(shape=(len(data),)) + residual_noise = np.zeros(shape=(len(data),)) + numchunks = int(np.ceil(length_data / segment_size)) + for i in tqdm(range(numchunks)): - if pointer+segment_size [B, T, F, N] - DenseNet Block consisting of "num_layers" densely connected convolutional layers - ''' - def __init__(self, num_layers, N, ksize,activation): - ''' - num_layers: number of densely connected conv. layers - N: Number of filters (same in each layer) - ksize: Kernel size (same in each layer) - ''' - super(DenseBlock, self).__init__() - self.activation=activation - self.paddings_1=get_paddings(ksize) - self.H=[] - self.num_layers=num_layers + +class DenseBlock(layers.Layer): + """ + [B, T, F, N] => [B, T, F, N] + DenseNet Block consisting of "num_layers" densely connected convolutional layers + """ + + def __init__(self, num_layers, N, ksize, activation): + """ + num_layers: number of densely connected conv. layers + N: Number of filters (same in each layer) + ksize: Kernel size (same in each layer) + """ + super(DenseBlock, self).__init__() + self.activation = activation + + self.paddings_1 = get_paddings(ksize) + self.H = [] + self.num_layers = num_layers for i in range(num_layers): - self.H.append(layers.Conv2D(filters=N, - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=self.activation)) + self.H.append( + layers.Conv2D( + filters=N, + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=self.activation, + ) + ) def call(self, x): - - x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC') + x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC") x_ = self.H[0](x_) - if self.num_layers>1: + if self.num_layers > 1: for h in self.H[1:]: x = tf.concat([x_, x], axis=-1) - x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC') - x_ = h(x_) + x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC") + x_ = h(x_) return x_ class FinalBlock(layers.Layer): - ''' - [B, T, F, N] => [B, T, F, 2] + """ + [B, T, F, N] => [B, T, F, 2] Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram. - ''' + """ + def __init__(self): super(FinalBlock, self).__init__() - ksize=(3,3) - self.paddings_2=get_paddings(ksize) - self.conv2=layers.Conv2D(filters=2, - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=None) + ksize = (3, 3) + self.paddings_2 = get_paddings(ksize) + self.conv2 = layers.Conv2D( + filters=2, + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=None, + ) - - def call(self, inputs ): - - x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC') - pred=self.conv2(x) + def call(self, inputs): + x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC") + pred = self.conv2(x) return pred + + class SAM(layers.Layer): - ''' + """ [B, T, F, N] => [B, T, F, N] , [B, T, F, N] Supervised Attention Module: The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones. - The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer. - The first stage output is then calculated adding the original input spectrogram to the residual noise. - The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function. + The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer. + The first stage output is then calculated adding the original input spectrogram to the residual noise. + The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function. + + """ - ''' def __init__(self, n_feat): super(SAM, self).__init__() - ksize=(3,3) - self.paddings_1=get_paddings(ksize) - self.conv1 = layers.Conv2D(filters=n_feat, - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=None) - ksize=(3,3) - self.paddings_2=get_paddings(ksize) - self.conv2=layers.Conv2D(filters=2, - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=None) + ksize = (3, 3) + self.paddings_1 = get_paddings(ksize) + self.conv1 = layers.Conv2D( + filters=n_feat, + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=None, + ) + ksize = (3, 3) + self.paddings_2 = get_paddings(ksize) + self.conv2 = layers.Conv2D( + filters=2, + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=None, + ) - ksize=(3,3) - self.paddings_3=get_paddings(ksize) - self.conv3 = layers.Conv2D(filters=n_feat, - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=None) - self.cropadd=CropAddBlock() + ksize = (3, 3) + self.paddings_3 = get_paddings(ksize) + self.conv3 = layers.Conv2D( + filters=n_feat, + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=None, + ) + self.cropadd = CropAddBlock() def call(self, inputs, input_spectrogram): - x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC') + x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC") x1 = self.conv1(x1) - x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC') - x=self.conv2(x) + x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC") + x = self.conv2(x) - #residual prediction - pred = layers.Add()([x, input_spectrogram]) #features to next stage + # residual prediction + pred = layers.Add()([x, input_spectrogram]) # features to next stage - x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC') - M=self.conv3(x3) + x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC") + M = self.conv3(x3) - M= tf.keras.activations.sigmoid(M) - x1=layers.Multiply()([x1, M]) - x1 = layers.Add()([x1, inputs]) #features to next stage + M = tf.keras.activations.sigmoid(M) + x1 = layers.Multiply()([x1, M]) + x1 = layers.Add()([x1, inputs]) # features to next stage return x1, pred class AddFreqEncoding(layers.Layer): - ''' - [B, T, F, 2] => [B, T, F, 12] + """ + [B, T, F, 2] => [B, T, F, 12] Generates frequency positional embeddings and concatenates them as 10 extra channels This function is optimized for F=1025 - ''' + """ + def __init__(self, f_dim): super(AddFreqEncoding, self).__init__() pi = tf.constant(m.pi) - pi=tf.cast(pi,'float32') - self.f_dim=f_dim #f_dim is fixed - n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32') - coss=tf.math.cos(pi*n) - f_channel = tf.expand_dims(coss, -1) #(1025,1) - self.fembeddings= f_channel - - for k in range(1,10): - coss=tf.math.cos(2**k*pi*n) - f_channel = tf.expand_dims(coss, -1) #(1025,1) - self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10) - + pi = tf.cast(pi, "float32") + self.f_dim = f_dim # f_dim is fixed + n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32") + coss = tf.math.cos(pi * n) + f_channel = tf.expand_dims(coss, -1) # (1025,1) + self.fembeddings = f_channel + + for k in range(1, 10): + coss = tf.math.cos(2**k * pi * n) + f_channel = tf.expand_dims(coss, -1) # (1025,1) + self.fembeddings = tf.concat( + [self.fembeddings, f_channel], axis=-1 + ) # (1025,10) def call(self, input_tensor): - batch_size_tensor = tf.shape(input_tensor)[0] # get batch size time_dim = tf.shape(input_tensor)[1] # get time dimension - fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]) - - - return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12) + fembeddings_2 = tf.broadcast_to( + self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10] + ) + + return tf.concat([input_tensor, fembeddings_2], axis=-1) # (batch,427,1025,12) def get_paddings(K): - return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]]) + return tf.constant( + [ + [0, 0], + [K[0] // 2, K[0] // 2 - (1 - K[0] % 2)], + [K[1] // 2, K[1] // 2 - (1 - K[1] % 2)], + [0, 0], + ] + ) + class Decoder(layers.Layer): - ''' - [B, T, F, N] , skip connections => [B, T, F, N] + """ + [B, T, F, N] , skip connections => [B, T, F, N] Decoder side of the U-Net subnetwork. - ''' + """ + def __init__(self, Ns, Ss, unet_args): super(Decoder, self).__init__() - self.Ns=Ns - self.Ss=Ss - self.activation=unet_args.activation - self.depth=unet_args.depth + self.Ns = Ns + self.Ss = Ss + self.activation = unet_args.activation + self.depth = unet_args.depth + ksize = (3, 3) + self.paddings_3 = get_paddings(ksize) + self.conv2d_3 = layers.Conv2D( + filters=self.Ns[self.depth], + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=self.activation, + ) - ksize=(3,3) - self.paddings_3=get_paddings(ksize) - self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth], - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=self.activation) + self.cropadd = CropAddBlock() - self.cropadd=CropAddBlock() - - self.dblocks=[] + self.dblocks = [] for i in range(self.depth): - self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc)) + self.dblocks.append( + D_Block( + layer_idx=i, + N=self.Ns[i], + S=self.Ss[i], + activation=self.activation, + num_tfc=unet_args.num_tfc, + ) + ) + + def call(self, inputs, contracting_layers): + x = inputs + for i in range(self.depth, 0, -1): + x = self.dblocks[i - 1](x, contracting_layers[i - 1]) + return x - def call(self,inputs, contracting_layers): - x=inputs - for i in range(self.depth,0,-1): - x=self.dblocks[i-1](x, contracting_layers[i-1]) - return x class Encoder(tf.keras.Model): - - ''' - [B, T, F, N] => skip connections , [B, T, F, N_4] + """ + [B, T, F, N] => skip connections , [B, T, F, N_4] Encoder side of the U-Net subnetwork. - ''' + """ + def __init__(self, Ns, Ss, unet_args): super(Encoder, self).__init__() - self.Ns=Ns - self.Ss=Ss - self.activation=unet_args.activation - self.depth=unet_args.depth + self.Ns = Ns + self.Ss = Ss + self.activation = unet_args.activation + self.depth = unet_args.depth self.contracting_layers = {} - self.eblocks=[] + self.eblocks = [] for i in range(self.depth): - self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc)) + self.eblocks.append( + E_Block( + layer_idx=i, + N0=self.Ns[i], + N=self.Ns[i + 1], + S=self.Ss[i], + activation=self.activation, + num_tfc=unet_args.num_tfc, + ) + ) - self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc) + self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc) def call(self, inputs): - x=inputs + x = inputs for i in range(self.depth): + x, x_contract = self.eblocks[i](x) - x, x_contract=self.eblocks[i](x) - - self.contracting_layers[i] = x_contract #if remove 0, correct this - x=self.i_block(x) + self.contracting_layers[i] = x_contract # if remove 0, correct this + x = self.i_block(x) return x, self.contracting_layers -class MultiStage_denoise(tf.keras.Model): - def __init__(self, unet_args=None): +class MultiStage_denoise(tf.keras.Model): + def __init__(self, unet_args=None): super(MultiStage_denoise, self).__init__() - self.activation=unet_args.activation - self.depth=unet_args.depth + self.activation = unet_args.activation + self.depth = unet_args.depth if unet_args.use_fencoding: - self.freq_encoding=AddFreqEncoding(unet_args.f_dim) - self.use_sam=unet_args.use_SAM - self.use_fencoding=unet_args.use_fencoding - self.num_stages=unet_args.num_stages - #Encoder - self.Ns= [32,64,64,128,128,256,512] - self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)] - - #initial feature extractor - ksize=(7,7) - self.paddings_1=get_paddings(ksize) - self.conv2d_1 = layers.Conv2D(filters=self.Ns[0], - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=self.activation) + self.freq_encoding = AddFreqEncoding(unet_args.f_dim) + self.use_sam = unet_args.use_SAM + self.use_fencoding = unet_args.use_fencoding + self.num_stages = unet_args.num_stages + # Encoder + self.Ns = [32, 64, 64, 128, 128, 256, 512] + self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)] + # initial feature extractor + ksize = (7, 7) + self.paddings_1 = get_paddings(ksize) + self.conv2d_1 = layers.Conv2D( + filters=self.Ns[0], + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=self.activation, + ) - self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args) - self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args) + self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args) + self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args) self.cropconcat = CropConcatBlock() self.cropadd = CropAddBlock() - self.finalblock=FinalBlock() + self.finalblock = FinalBlock() - if self.num_stages>1: - self.sam_1=SAM(self.Ns[0]) + if self.num_stages > 1: + self.sam_1 = SAM(self.Ns[0]) - #initial feature extractor - ksize=(7,7) - self.paddings_2=get_paddings(ksize) - self.conv2d_2 = layers.Conv2D(filters=self.Ns[0], - kernel_size=ksize, - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID', - activation=self.activation) - + # initial feature extractor + ksize = (7, 7) + self.paddings_2 = get_paddings(ksize) + self.conv2d_2 = layers.Conv2D( + filters=self.Ns[0], + kernel_size=ksize, + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + activation=self.activation, + ) - self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args) - self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args) + self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args) + self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args) @tf.function() def call(self, inputs): - if self.use_fencoding: - x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12 + x_w_freq = self.freq_encoding(inputs) # None, None, 1025, 12 else: - x_w_freq=inputs + x_w_freq = inputs - #intitial feature extractor - x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC') - x=self.conv2d_1(x) #None, None, 1025, 32 + # intitial feature extractor + x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC") + x = self.conv2d_1(x) # None, None, 1025, 32 - x, contracting_layers_s1= self.encoder_s1(x) - #decoder - feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features + x, contracting_layers_s1 = self.encoder_s1(x) + # decoder + feats_s1 = self.decoder_s1( + x, contracting_layers_s1 + ) # None, None, 1025, 32 features + + if self.num_stages > 1: + # SAM module + Fout, pred_stage_1 = self.sam_1(feats_s1, inputs) + + # intitial feature extractor + x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC") + x = self.conv2d_2(x) - if self.num_stages>1: - #SAM module - Fout, pred_stage_1=self.sam_1(feats_s1,inputs) - - #intitial feature extractor - x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC') - x=self.conv2d_2(x) - if self.use_sam: x = tf.concat([x, Fout], axis=-1) else: - x = tf.concat([x,feats_s1], axis=-1) - - x, contracting_layers_s2= self.encoder_s2(x) + x = tf.concat([x, feats_s1], axis=-1) - feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features - - #consider implementing a third stage? + x, contracting_layers_s2 = self.encoder_s2(x) - pred_stage_2=self.finalblock(feats_s2) + feats_s2 = self.decoder_s2( + x, contracting_layers_s2 + ) # None, None, 1025, 32 features + + # consider implementing a third stage? + + pred_stage_2 = self.finalblock(feats_s2) return pred_stage_2, pred_stage_1 - else: - pred_stage_1=self.finalblock(feats_s1) + else: + pred_stage_1 = self.finalblock(feats_s1) return pred_stage_1 - + + class I_Block(layers.Layer): - ''' - [B, T, F, N] => [B, T, F, N] + """ + [B, T, F, N] => [B, T, F, N] Intermediate block: Basically, a densenet block with a residual connection - ''' - def __init__(self,N,activation, num_tfc, **kwargs): + """ + + def __init__(self, N, activation, num_tfc, **kwargs): super(I_Block, self).__init__(**kwargs) - ksize=(3,3) - self.tfc=DenseBlock(num_tfc,N,ksize, activation) + ksize = (3, 3) + self.tfc = DenseBlock(num_tfc, N, ksize, activation) - self.conv2d_res= layers.Conv2D(filters=N, - kernel_size=(1,1), - kernel_initializer=TruncatedNormal(), - strides=1, - padding='VALID') + self.conv2d_res = layers.Conv2D( + filters=N, + kernel_size=(1, 1), + kernel_initializer=TruncatedNormal(), + strides=1, + padding="VALID", + ) - def call(self,inputs): - x=self.tfc(inputs) + def call(self, inputs): + x = self.tfc(inputs) - inputs_proj=self.conv2d_res(inputs) - return layers.Add()([x,inputs_proj]) + inputs_proj = self.conv2d_res(inputs) + return layers.Add()([x, inputs_proj]) class E_Block(layers.Layer): - - def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs): + def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs): super(E_Block, self).__init__(**kwargs) - self.layer_idx=layer_idx - self.N0=N0 - self.N=N - self.S=S - self.activation=activation - self.i_block=I_Block(N0,activation,num_tfc) - - ksize=(S[0]+2,S[1]+2) - self.paddings_2=get_paddings(ksize) - self.conv2d_2 = layers.Conv2D(filters=N, - kernel_size=(S[0]+2,S[1]+2), - kernel_initializer=TruncatedNormal(), - strides=S, - padding='VALID', - activation=self.activation) + self.layer_idx = layer_idx + self.N0 = N0 + self.N = N + self.S = S + self.activation = activation + self.i_block = I_Block(N0, activation, num_tfc) + ksize = (S[0] + 2, S[1] + 2) + self.paddings_2 = get_paddings(ksize) + self.conv2d_2 = layers.Conv2D( + filters=N, + kernel_size=(S[0] + 2, S[1] + 2), + kernel_initializer=TruncatedNormal(), + strides=S, + padding="VALID", + activation=self.activation, + ) def call(self, inputs, training=None, **kwargs): - x=self.i_block(inputs) - - x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC') + x = self.i_block(inputs) + + x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC") x_down = self.conv2d_2(x_down) return x_down, x - def get_config(self): - return dict(layer_idx=self.layer_idx, - N=self.N, - S=self.S, - **super(E_Block, self).get_config() - ) + return dict( + layer_idx=self.layer_idx, + N=self.N, + S=self.S, + **super(E_Block, self).get_config(), + ) + + class D_Block(layers.Layer): - - def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs): + def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs): super(D_Block, self).__init__(**kwargs) - self.layer_idx=layer_idx - self.N=N - self.S=S - self.activation=activation - ksize=(S[0]+2, S[1]+2) - self.paddings_1=get_paddings(ksize) + self.layer_idx = layer_idx + self.N = N + self.S = S + self.activation = activation + ksize = (S[0] + 2, S[1] + 2) + self.paddings_1 = get_paddings(ksize) - self.tconv_1= layers.Conv2DTranspose(filters=N, - kernel_size=(S[0]+2, S[1]+2), - kernel_initializer=TruncatedNormal(), - strides=S, - activation=self.activation, - padding='VALID') + self.tconv_1 = layers.Conv2DTranspose( + filters=N, + kernel_size=(S[0] + 2, S[1] + 2), + kernel_initializer=TruncatedNormal(), + strides=S, + activation=self.activation, + padding="VALID", + ) - self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest') + self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest") - self.projection = layers.Conv2D(filters=N, - kernel_size=(1,1), - kernel_initializer=TruncatedNormal(), - strides=1, - activation=self.activation, - padding='VALID') - self.cropadd=CropAddBlock() - self.cropconcat=CropConcatBlock() + self.projection = layers.Conv2D( + filters=N, + kernel_size=(1, 1), + kernel_initializer=TruncatedNormal(), + strides=1, + activation=self.activation, + padding="VALID", + ) + self.cropadd = CropAddBlock() + self.cropconcat = CropConcatBlock() - self.i_block=I_Block(N,activation,num_tfc) + self.i_block = I_Block(N, activation, num_tfc) - def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs): + def call( + self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs + ): x = inputs - x=tf.pad(x, self.paddings_1, mode='SYMMETRIC') + x = tf.pad(x, self.paddings_1, mode="SYMMETRIC") x = self.tconv_1(inputs) - x2= self.upsampling(inputs) + x2 = self.upsampling(inputs) - if x2.shape[-1]!=x.shape[-1]: - x2= self.projection(x2) + if x2.shape[-1] != x.shape[-1]: + x2 = self.projection(x2) - x= self.cropadd(x,x2) + x = self.cropadd(x, x2) + x = self.cropconcat(x, bridge) - x=self.cropconcat(x,bridge) - - x=self.i_block(x) + x = self.i_block(x) return x def get_config(self): - return dict(layer_idx=self.layer_idx, - N=self.N, - S=self.S, - **super(D_Block, self).get_config() - ) + return dict( + layer_idx=self.layer_idx, + N=self.N, + S=self.S, + **super(D_Block, self).get_config(), + ) + class CropAddBlock(layers.Layer): - - def call(self,down_layer, x, **kwargs): - x1_shape = tf.shape(down_layer) - x2_shape = tf.shape(x) - - - height_diff = (x1_shape[1] - x2_shape[1]) // 2 - width_diff = (x1_shape[2] - x2_shape[2]) // 2 - - down_layer_cropped = down_layer[:, - height_diff: (x2_shape[1] + height_diff), - width_diff: (x2_shape[2] + width_diff), - :] - - x = layers.Add()([down_layer_cropped, x]) - return x - -class CropConcatBlock(layers.Layer): - def call(self, down_layer, x, **kwargs): x1_shape = tf.shape(down_layer) x2_shape = tf.shape(x) @@ -477,10 +520,31 @@ class CropConcatBlock(layers.Layer): height_diff = (x1_shape[1] - x2_shape[1]) // 2 width_diff = (x1_shape[2] - x2_shape[2]) // 2 - down_layer_cropped = down_layer[:, - height_diff: (x2_shape[1] + height_diff), - width_diff: (x2_shape[2] + width_diff), - :] + down_layer_cropped = down_layer[ + :, + height_diff : (x2_shape[1] + height_diff), + width_diff : (x2_shape[2] + width_diff), + :, + ] + + x = layers.Add()([down_layer_cropped, x]) + return x + + +class CropConcatBlock(layers.Layer): + def call(self, down_layer, x, **kwargs): + x1_shape = tf.shape(down_layer) + x2_shape = tf.shape(x) + + height_diff = (x1_shape[1] - x2_shape[1]) // 2 + width_diff = (x1_shape[2] - x2_shape[2]) // 2 + + down_layer_cropped = down_layer[ + :, + height_diff : (x2_shape[1] + height_diff), + width_diff : (x2_shape[2] + width_diff), + :, + ] x = tf.concat([down_layer_cropped, x], axis=-1) return x