fix formatting & linters
This commit is contained in:
parent
114fce7c84
commit
9870240572
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
experiments
|
||||
outputs
|
||||
__pycache__
|
||||
@ -4,14 +4,12 @@ import tensorflow as tf
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
from scipy.fft import fft, ifft
|
||||
import soundfile as sf
|
||||
import math
|
||||
import pandas as pd
|
||||
import scipy as sp
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis.
|
||||
def __noise_sample_generator(info_file, fs, length_seq, split):
|
||||
head = os.path.split(info_file)[0]
|
||||
@ -26,14 +24,25 @@ def __noise_sample_generator(info_file,fs, length_seq, split):
|
||||
for i in r:
|
||||
segments = ast.literal_eval(load_data_split.loc[i, "segments"])
|
||||
if split == "test":
|
||||
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
|
||||
loaded_data, Fs = sf.read(
|
||||
os.path.join(
|
||||
head,
|
||||
load_data_split["recording"].loc[i],
|
||||
load_data_split["largest_segment"].loc[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
num = np.random.randint(0, len(segments))
|
||||
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
|
||||
assert(fs==Fs, "wrong sampling rate")
|
||||
loaded_data, Fs = sf.read(
|
||||
os.path.join(
|
||||
head, load_data_split["recording"].loc[i], segments[num]
|
||||
)
|
||||
)
|
||||
assert fs == Fs, "wrong sampling rate"
|
||||
|
||||
yield __extend_sample_by_repeating(loaded_data, fs, length_seq)
|
||||
|
||||
|
||||
def __extend_sample_by_repeating(data, fs, seq_len):
|
||||
rpm = 78
|
||||
target_samp = seq_len
|
||||
@ -57,12 +66,11 @@ def __extend_sample_by_repeating(data, fs,seq_len):
|
||||
|
||||
overhead = len(data) % period_sam
|
||||
|
||||
if(overhead>bls):
|
||||
if overhead > bls:
|
||||
complete_periods = (len(data) // period_sam) * period_sam
|
||||
else:
|
||||
complete_periods = (len(data) // period_sam - 1) * period_sam
|
||||
|
||||
|
||||
a = np.multiply(data[0:bls], window_left)
|
||||
b = np.multiply(data[complete_periods : complete_periods + bls], window_right)
|
||||
c_1 = np.concatenate((data[0:complete_periods, :], b))
|
||||
@ -71,10 +79,9 @@ def __extend_sample_by_repeating(data, fs,seq_len):
|
||||
|
||||
large_data[0 : complete_periods + bls, :] = c_1
|
||||
|
||||
|
||||
pointer = complete_periods
|
||||
not_finished = True
|
||||
while (not_finished):
|
||||
while not_finished:
|
||||
if target_samp > pointer + complete_periods + bls:
|
||||
large_data[pointer : pointer + complete_periods + bls] += c_2
|
||||
pointer += complete_periods
|
||||
@ -86,8 +93,9 @@ def __extend_sample_by_repeating(data, fs,seq_len):
|
||||
return large_data
|
||||
|
||||
|
||||
def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
|
||||
|
||||
def generate_real_recordings_data(
|
||||
path_recordings, fs=44100, seg_len_s=15, stereo=False
|
||||
):
|
||||
records_info = os.path.join(path_recordings, "audio_files.txt")
|
||||
num_lines = sum(1 for line in open(records_info))
|
||||
f = open(records_info, "r")
|
||||
@ -112,8 +120,18 @@ def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stere
|
||||
|
||||
return records
|
||||
|
||||
def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False):
|
||||
|
||||
def generate_paired_data_test_formal(
|
||||
path_pianos,
|
||||
path_noises,
|
||||
noise_amount="low_snr",
|
||||
num_samples=-1,
|
||||
fs=44100,
|
||||
seg_len_s=5,
|
||||
extend=True,
|
||||
stereo=False,
|
||||
prenoise=False,
|
||||
):
|
||||
print(num_samples)
|
||||
segments_clean = []
|
||||
segments_noisy = []
|
||||
@ -134,9 +152,13 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
|
||||
train_samples = sorted(train_samples)
|
||||
|
||||
if prenoise:
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, extend, "test") #Adds 1s of silence add the begiing, longer noise
|
||||
noise_generator = __noise_sample_generator(
|
||||
noises_info, fs, seg_len + fs, extend, "test"
|
||||
) # Adds 1s of silence add the begiing, longer noise
|
||||
else:
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, extend, "test") #this will take care of everything
|
||||
noise_generator = __noise_sample_generator(
|
||||
noises_info, fs, seg_len, extend, "test"
|
||||
) # this will take care of everything
|
||||
# load data clean files
|
||||
for file in tqdm(train_samples): # add [1:5] for testing
|
||||
data_clean, samplerate = sf.read(file)
|
||||
@ -159,19 +181,22 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
|
||||
num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
|
||||
print(num_frames)
|
||||
if num_frames == 0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
data_clean = np.concatenate(
|
||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
||||
axis=0,
|
||||
)
|
||||
num_frames = 1
|
||||
|
||||
data_not_finished = True
|
||||
pointer = 0
|
||||
while(data_not_finished):
|
||||
while data_not_finished:
|
||||
if i >= num_samples:
|
||||
break
|
||||
segment = data_clean[pointer : pointer + seg_len]
|
||||
pointer = pointer + hop_size
|
||||
if pointer + seg_len > len(data_clean):
|
||||
data_not_finished = False
|
||||
segment=segment.astype('float32')
|
||||
segment = segment.astype("float32")
|
||||
|
||||
# SNRs=np.random.uniform(2,20)
|
||||
snr = SNRs[i]
|
||||
@ -195,23 +220,26 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
|
||||
|
||||
# sum both signals according to snr
|
||||
if prenoise:
|
||||
segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
segment = np.concatenate(
|
||||
(np.zeros(shape=(fs,)), segment), axis=0
|
||||
) # add one second of silence
|
||||
summed = (
|
||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
||||
) # not sure if this is correct, maybe revisit later!!
|
||||
|
||||
summed=summed.astype('float32')
|
||||
summed = summed.astype("float32")
|
||||
# yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
|
||||
summed = 10.0 ** (scale / 10.0) * summed
|
||||
segment = 10.0 ** (scale / 10.0) * segment
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
segments_clean.append(segment.astype('float32'))
|
||||
segments_noisy.append(summed.astype("float32"))
|
||||
segments_clean.append(segment.astype("float32"))
|
||||
i = i + 1
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
|
||||
|
||||
def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5):
|
||||
segments_clean = []
|
||||
segments_noisy = []
|
||||
seg_len = fs * seg_len_s
|
||||
@ -222,9 +250,10 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
|
||||
train_samples = glob.glob(os.path.join(path, "*.wav"))
|
||||
train_samples = sorted(train_samples)
|
||||
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
|
||||
noise_generator = __noise_sample_generator(
|
||||
noises_info, fs, seg_len, "test"
|
||||
) # this will take care of everything
|
||||
# load data clean files
|
||||
jj=0
|
||||
for file in tqdm(train_samples): # add [1:5] for testing
|
||||
data_clean, samplerate = sf.read(file)
|
||||
if samplerate != fs:
|
||||
@ -243,13 +272,18 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
|
||||
|
||||
num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
|
||||
if num_frames == 0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
data_clean = np.concatenate(
|
||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
||||
axis=0,
|
||||
)
|
||||
num_frames = 1
|
||||
|
||||
pointer = 0
|
||||
segment = data_clean[pointer : pointer + (seg_len - 2 * fs)]
|
||||
segment=segment.astype('float32')
|
||||
segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #I hope its ok
|
||||
segment = segment.astype("float32")
|
||||
segment = np.concatenate(
|
||||
(np.zeros(shape=(2 * fs,)), segment), axis=0
|
||||
) # I hope its ok
|
||||
# segments_clean.append(segment)
|
||||
|
||||
for snr in SNRs:
|
||||
@ -269,17 +303,21 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
|
||||
snr = 10.0 ** (snr / 10.0)
|
||||
|
||||
# sum both signals according to snr
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
summed=summed.astype('float32')
|
||||
summed = (
|
||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
||||
) # not sure if this is correct, maybe revisit later!!
|
||||
summed = summed.astype("float32")
|
||||
# yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
segments_clean.append(segment.astype('float32'))
|
||||
segments_noisy.append(summed.astype("float32"))
|
||||
segments_clean.append(segment.astype("float32"))
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5):
|
||||
|
||||
def generate_val_data(
|
||||
path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5
|
||||
):
|
||||
val_samples = []
|
||||
for path in path_music:
|
||||
val_samples.extend(glob.glob(os.path.join(path, "*.wav")))
|
||||
@ -304,7 +342,6 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
|
||||
seg_len = fs * seg_len_s
|
||||
segments_clean = []
|
||||
for file in tqdm(data_clean_loaded):
|
||||
|
||||
# framify arguments: seg_len, hop_size
|
||||
hop_size = int(seg_len) # no overlap
|
||||
|
||||
@ -313,7 +350,7 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
|
||||
for i in range(0, int(num_frames)):
|
||||
segment = file[pointer : pointer + seg_len]
|
||||
pointer = pointer + hop_size
|
||||
segment=segment.astype('float32')
|
||||
segment = segment.astype("float32")
|
||||
segments_clean.append(segment)
|
||||
|
||||
del data_clean_loaded
|
||||
@ -323,8 +360,9 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
|
||||
# noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
|
||||
noises_info = os.path.join(path_noises, "info.csv")
|
||||
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything
|
||||
|
||||
noise_generator = __noise_sample_generator(
|
||||
noises_info, fs, seg_len, split
|
||||
) # this will take care of everything
|
||||
|
||||
# generate noisy segments
|
||||
# load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
|
||||
@ -345,7 +383,6 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
|
||||
data_clean = segments_clean[i]
|
||||
# configure sizes
|
||||
|
||||
|
||||
# estimate clean signal power
|
||||
power_clean = np.var(data_clean)
|
||||
# estimate noise power
|
||||
@ -354,33 +391,37 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
|
||||
snr = 10.0 ** (SNRs[i] / 10.0)
|
||||
|
||||
# sum both signals according to snr
|
||||
summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
summed = (
|
||||
data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
||||
) # not sure if this is correct, maybe revisit later!!
|
||||
# the rest is normal
|
||||
|
||||
summed = 10.0 ** (scales[i] / 10.0) * summed
|
||||
segments_clean[i] = 10.0 ** (scales[i] / 10.0) * segments_clean[i]
|
||||
|
||||
segments_noisy.append(summed.astype('float32'))
|
||||
segments_noisy.append(summed.astype("float32"))
|
||||
|
||||
return segments_noisy, segments_clean
|
||||
|
||||
|
||||
|
||||
def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False):
|
||||
|
||||
def generator_train(
|
||||
path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False
|
||||
):
|
||||
train_samples = []
|
||||
for path in path_music:
|
||||
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"), "*.wav")))
|
||||
|
||||
seg_len = fs * seg_len_s
|
||||
noises_info = os.path.join(path_noises.decode("utf-8"), "info.csv")
|
||||
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything
|
||||
noise_generator = __noise_sample_generator(
|
||||
noises_info, fs, seg_len, split.decode("utf-8")
|
||||
) # this will take care of everything
|
||||
# load data clean files
|
||||
while True:
|
||||
random.shuffle(train_samples)
|
||||
for file in train_samples:
|
||||
data, samplerate = sf.read(file)
|
||||
assert(samplerate==fs, "wrong sampling rate")
|
||||
assert samplerate == fs, "wrong sampling rate"
|
||||
data_clean = data
|
||||
# Stereo to mono
|
||||
if len(data.shape) > 1:
|
||||
@ -396,27 +437,33 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
|
||||
|
||||
num_frames = np.floor(len(data_clean) / seg_len)
|
||||
if num_frames == 0:
|
||||
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||
data_clean = np.concatenate(
|
||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
||||
axis=0,
|
||||
)
|
||||
num_frames = 1
|
||||
pointer = 0
|
||||
data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation
|
||||
data_clean = np.roll(
|
||||
data_clean, np.random.randint(0, seg_len)
|
||||
) # if only one frame, roll it for augmentation
|
||||
elif num_frames > 1:
|
||||
pointer=np.random.randint(0,hop_size) #initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
|
||||
pointer = np.random.randint(
|
||||
0, hop_size
|
||||
) # initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
|
||||
else:
|
||||
pointer = 0
|
||||
|
||||
data_not_finished = True
|
||||
while(data_not_finished):
|
||||
while data_not_finished:
|
||||
segment = data_clean[pointer : pointer + seg_len]
|
||||
pointer = pointer + hop_size
|
||||
if pointer + seg_len > len(data_clean):
|
||||
data_not_finished = False
|
||||
segment=segment.astype('float32')
|
||||
segment = segment.astype("float32")
|
||||
|
||||
SNRs = np.random.uniform(2, 20)
|
||||
scale = np.random.uniform(-6, 4)
|
||||
|
||||
|
||||
# load noise signal
|
||||
data_noise = next(noise_generator)
|
||||
data_noise = np.mean(data_noise, axis=1)
|
||||
@ -427,9 +474,13 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
|
||||
# configure sizes
|
||||
if stereo:
|
||||
# estimate clean signal power
|
||||
power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
|
||||
power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var(
|
||||
segment[:, 1]
|
||||
)
|
||||
# estimate noise power
|
||||
power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
|
||||
power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var(
|
||||
new_noise[:, 1]
|
||||
)
|
||||
else:
|
||||
# estimate clean signal power
|
||||
power_clean = np.var(segment)
|
||||
@ -438,39 +489,63 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
|
||||
|
||||
snr = 10.0 ** (SNRs / 10.0)
|
||||
|
||||
|
||||
# sum both signals according to snr
|
||||
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||
summed = (
|
||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
||||
) # not sure if this is correct, maybe revisit later!!
|
||||
summed = 10.0 ** (scale / 10.0) * summed
|
||||
segment = 10.0 ** (scale / 10.0) * segment
|
||||
|
||||
summed=summed.astype('float32')
|
||||
summed = summed.astype("float32")
|
||||
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||
|
||||
def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) :
|
||||
|
||||
def load_data(
|
||||
buffer_size,
|
||||
path_music_train,
|
||||
path_music_val,
|
||||
path_noises,
|
||||
fs=44100,
|
||||
seg_len_s=5,
|
||||
extend=True,
|
||||
stereo=False,
|
||||
):
|
||||
print("Generating train dataset")
|
||||
trainshape = int(fs * seg_len_s)
|
||||
|
||||
dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) )
|
||||
|
||||
dataset_train = tf.data.Dataset.from_generator(
|
||||
generator_train,
|
||||
args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo),
|
||||
output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))),
|
||||
output_types=(tf.float32, tf.float32),
|
||||
)
|
||||
|
||||
print("Generating validation dataset")
|
||||
segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s)
|
||||
segments_noisy, segments_clean = generate_val_data(
|
||||
path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s
|
||||
)
|
||||
|
||||
dataset_val = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
|
||||
return dataset_train.shuffle(buffer_size), dataset_val
|
||||
|
||||
|
||||
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
|
||||
print("Generating test dataset")
|
||||
segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||
segments_noisy, segments_clean = generate_test_data(
|
||||
path_pianos_test, path_noises, extend=True, **kwargs
|
||||
)
|
||||
dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
# dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
|
||||
# train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||
return dataset_test
|
||||
|
||||
|
||||
def load_data_formal(path_pianos_test, path_noises, **kwargs):
|
||||
print("Generating test dataset")
|
||||
segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||
segments_noisy, segments_clean = generate_paired_data_test_formal(
|
||||
path_pianos_test, path_noises, extend=True, **kwargs
|
||||
)
|
||||
print("segments::")
|
||||
print(len(segments_noisy))
|
||||
dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||
@ -478,6 +553,7 @@ def load_data_formal( path_pianos_test, path_noises, **kwargs) :
|
||||
# train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||
return dataset_test
|
||||
|
||||
|
||||
def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
|
||||
print("Generating real test dataset")
|
||||
|
||||
|
||||
130
inference.py
130
inference.py
@ -4,6 +4,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import tensorflow as tf
|
||||
@ -16,33 +17,45 @@ def run(args):
|
||||
|
||||
unet_model = unet.build_model_denoise(unet_args=args.unet)
|
||||
|
||||
ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint')
|
||||
ckpt = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
|
||||
)
|
||||
unet_model.load_weights(ckpt)
|
||||
|
||||
def do_stft(noisy):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size = args.stft.win_size
|
||||
hop_size = args.stft.hop_size
|
||||
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
stft_signal_noisy = tf.signal.stft(
|
||||
noisy,
|
||||
frame_length=win_size,
|
||||
window_fn=window_fn,
|
||||
frame_step=hop_size,
|
||||
pad_end=True,
|
||||
)
|
||||
stft_noisy_stacked = tf.stack(
|
||||
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
def do_istft(data):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size = args.stft.win_size
|
||||
hop_size = args.stft.hop_size
|
||||
|
||||
inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
|
||||
inv_window_fn = tf.signal.inverse_stft_window_fn(
|
||||
hop_size, forward_window_fn=window_fn
|
||||
)
|
||||
|
||||
pred_cpx = data[..., 0] + 1j * data[..., 1]
|
||||
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
|
||||
pred_time = tf.signal.inverse_stft(
|
||||
pred_cpx, win_size, hop_size, window_fn=inv_window_fn
|
||||
)
|
||||
return pred_time
|
||||
|
||||
audio = str(args.inference.audio)
|
||||
@ -57,8 +70,6 @@ def run(args):
|
||||
|
||||
data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1)
|
||||
|
||||
|
||||
|
||||
segment_size = 44100 * 5 # 20s segments
|
||||
|
||||
length_data = len(data)
|
||||
@ -66,7 +77,6 @@ def run(args):
|
||||
window = np.hanning(2 * overlapsize)
|
||||
window_right = window[overlapsize::]
|
||||
window_left = window[0:overlapsize]
|
||||
audio_finished=False
|
||||
pointer = 0
|
||||
denoised_data = np.zeros(shape=(len(data),))
|
||||
residual_noise = np.zeros(shape=(len(data),))
|
||||
@ -87,21 +97,72 @@ def run(args):
|
||||
residual_time = np.array(residual_time)
|
||||
|
||||
if pointer == 0:
|
||||
pred_time=np.concatenate((pred_time[0:int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||
residual_time=np.concatenate((residual_time[0:int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||
pred_time = np.concatenate(
|
||||
(
|
||||
pred_time[0 : int(segment_size - overlapsize)],
|
||||
np.multiply(
|
||||
pred_time[int(segment_size - overlapsize) : segment_size],
|
||||
window_right,
|
||||
),
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
residual_time = np.concatenate(
|
||||
(
|
||||
residual_time[0 : int(segment_size - overlapsize)],
|
||||
np.multiply(
|
||||
residual_time[
|
||||
int(segment_size - overlapsize) : segment_size
|
||||
],
|
||||
window_right,
|
||||
),
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
else:
|
||||
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||
pred_time = np.concatenate(
|
||||
(
|
||||
np.multiply(pred_time[0 : int(overlapsize)], window_left),
|
||||
pred_time[int(overlapsize) : int(segment_size - overlapsize)],
|
||||
np.multiply(
|
||||
pred_time[
|
||||
int(segment_size - overlapsize) : int(segment_size)
|
||||
],
|
||||
window_right,
|
||||
),
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
residual_time = np.concatenate(
|
||||
(
|
||||
np.multiply(residual_time[0 : int(overlapsize)], window_left),
|
||||
residual_time[
|
||||
int(overlapsize) : int(segment_size - overlapsize)
|
||||
],
|
||||
np.multiply(
|
||||
residual_time[
|
||||
int(segment_size - overlapsize) : int(segment_size)
|
||||
],
|
||||
window_right,
|
||||
),
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
denoised_data[pointer:pointer+segment_size]=denoised_data[pointer:pointer+segment_size]+pred_time
|
||||
residual_noise[pointer:pointer+segment_size]=residual_noise[pointer:pointer+segment_size]+residual_time
|
||||
denoised_data[pointer : pointer + segment_size] = (
|
||||
denoised_data[pointer : pointer + segment_size] + pred_time
|
||||
)
|
||||
residual_noise[pointer : pointer + segment_size] = (
|
||||
residual_noise[pointer : pointer + segment_size] + residual_time
|
||||
)
|
||||
|
||||
pointer = pointer + segment_size - overlapsize
|
||||
else:
|
||||
segment = data[pointer::]
|
||||
lensegment = len(segment)
|
||||
segment=np.concatenate((segment, np.zeros(shape=(int(segment_size-len(segment)),))), axis=0)
|
||||
audio_finished=True
|
||||
segment = np.concatenate(
|
||||
(segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0
|
||||
)
|
||||
# dostft
|
||||
segment_TF = do_stft(segment)
|
||||
|
||||
@ -121,11 +182,27 @@ def run(args):
|
||||
pred_time = pred_time
|
||||
residual_time = residual_time
|
||||
else:
|
||||
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||
pred_time = np.concatenate(
|
||||
(
|
||||
np.multiply(pred_time[0 : int(overlapsize)], window_left),
|
||||
pred_time[int(overlapsize) : int(segment_size)],
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
residual_time = np.concatenate(
|
||||
(
|
||||
np.multiply(residual_time[0 : int(overlapsize)], window_left),
|
||||
residual_time[int(overlapsize) : int(segment_size)],
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
denoised_data[pointer::]=denoised_data[pointer::]+pred_time[0:lensegment]
|
||||
residual_noise[pointer::]=residual_noise[pointer::]+residual_time[0:lensegment]
|
||||
denoised_data[pointer::] = (
|
||||
denoised_data[pointer::] + pred_time[0:lensegment]
|
||||
)
|
||||
residual_noise[pointer::] = (
|
||||
residual_noise[pointer::] + residual_time[0:lensegment]
|
||||
)
|
||||
|
||||
basename = os.path.splitext(audio)[0]
|
||||
wav_noisy_name = basename + "_noisy_input" + ".wav"
|
||||
@ -156,10 +233,3 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
109
train.py
109
train.py
@ -4,15 +4,14 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import tensorflow as tf
|
||||
import dataset_loader
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
import soundfile as sf
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
path_experiment = str(args.path_experiment)
|
||||
|
||||
@ -20,55 +19,70 @@ def run(args):
|
||||
os.makedirs(path_experiment)
|
||||
|
||||
path_music_train = args.dset.path_music_train
|
||||
path_music_test=args.dset.path_music_test
|
||||
path_music_validation = args.dset.path_music_validation
|
||||
|
||||
path_noise = args.dset.path_noise
|
||||
path_recordings=args.dset.path_recordings
|
||||
|
||||
fs = args.fs
|
||||
overlap=args.overlap
|
||||
seg_len_s_train = args.seg_len_s_train
|
||||
|
||||
batch_size = args.batch_size
|
||||
epochs = args.epochs
|
||||
|
||||
num_real_test_segments=args.num_real_test_segments
|
||||
buffer_size = args.buffer_size # for shuffle
|
||||
|
||||
tensorboard_logs = args.tensorboard_logs
|
||||
|
||||
def do_stft(noisy, clean=None):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size = args.stft.win_size
|
||||
hop_size = args.stft.hop_size
|
||||
|
||||
stft_signal_noisy = tf.signal.stft(
|
||||
noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
||||
)
|
||||
stft_noisy_stacked = tf.stack(
|
||||
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
|
||||
if clean!=None:
|
||||
|
||||
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
|
||||
|
||||
if clean is not None:
|
||||
stft_signal_clean = tf.signal.stft(
|
||||
clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
||||
)
|
||||
stft_clean_stacked = tf.stack(
|
||||
values=[
|
||||
tf.math.real(stft_signal_clean),
|
||||
tf.math.imag(stft_signal_clean),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
return stft_noisy_stacked, stft_clean_stacked
|
||||
else:
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
# Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
||||
|
||||
dataset_train, dataset_val=dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)
|
||||
dataset_train, dataset_val = dataset_loader.load_data(
|
||||
buffer_size,
|
||||
path_music_train,
|
||||
path_music_validation,
|
||||
path_noise,
|
||||
fs=fs,
|
||||
seg_len_s=seg_len_s_train,
|
||||
)
|
||||
|
||||
dataset_train=dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
dataset_val=dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
dataset_train = dataset_train.map(
|
||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
||||
)
|
||||
dataset_val = dataset_val.map(
|
||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
||||
)
|
||||
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
|
||||
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
|
||||
|
||||
with strategy.scope():
|
||||
# build the model
|
||||
@ -80,12 +94,17 @@ def run(args):
|
||||
loss = tf.keras.losses.MeanAbsoluteError()
|
||||
|
||||
if args.use_tensorboard:
|
||||
log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||
log_dir = os.path.join(
|
||||
tensorboard_logs,
|
||||
os.path.basename(path_experiment)
|
||||
+ "_"
|
||||
+ datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
|
||||
)
|
||||
train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
|
||||
val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")
|
||||
|
||||
# path where the checkpoints will be saved
|
||||
checkpoint_filepath=os.path.join(path_experiment, 'checkpoint')
|
||||
checkpoint_filepath = os.path.join(path_experiment, "checkpoint")
|
||||
|
||||
dataset_train = dataset_train.batch(batch_size)
|
||||
dataset_val = dataset_val.batch(batch_size)
|
||||
@ -106,27 +125,43 @@ def run(args):
|
||||
for epoch in range(epochs):
|
||||
total_loss = 0
|
||||
step_loss = 0
|
||||
for step in tqdm(range(args.steps_per_epoch), desc="Training epoch "+str(epoch)):
|
||||
for step in tqdm(
|
||||
range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
|
||||
):
|
||||
step_loss = trainer.distributed_training_step(iterator.get_next())
|
||||
total_loss += step_loss
|
||||
with train_summary_writer.as_default():
|
||||
tf.summary.scalar('batch_loss', step_loss, step=step)
|
||||
tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)
|
||||
tf.summary.scalar("batch_loss", step_loss, step=step)
|
||||
tf.summary.scalar(
|
||||
"batch_mean_absolute_error", trainer.train_mae.result(), step=step
|
||||
)
|
||||
|
||||
train_loss = total_loss / args.steps_per_epoch
|
||||
|
||||
for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
|
||||
trainer.distributed_test_step(x)
|
||||
|
||||
template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
|
||||
print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
|
||||
template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
|
||||
print(
|
||||
template.format(
|
||||
epoch + 1,
|
||||
train_loss,
|
||||
trainer.train_mae.result(),
|
||||
trainer.val_loss.result(),
|
||||
trainer.val_mae.result(),
|
||||
)
|
||||
)
|
||||
|
||||
with train_summary_writer.as_default():
|
||||
tf.summary.scalar('epoch_loss', train_loss, step=epoch)
|
||||
tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
|
||||
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
|
||||
tf.summary.scalar(
|
||||
"epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
|
||||
)
|
||||
with val_summary_writer.as_default():
|
||||
tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
|
||||
tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)
|
||||
tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
|
||||
tf.summary.scalar(
|
||||
"epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
|
||||
)
|
||||
|
||||
trainer.train_mae.reset_states()
|
||||
trainer.val_loss.reset_states()
|
||||
@ -137,10 +172,11 @@ def run(args):
|
||||
current_lr *= 1e-1
|
||||
trainer.optimizer.lr = current_lr
|
||||
try:
|
||||
unet_model.save_weights(checkpoint_filpath)
|
||||
except:
|
||||
unet_model.save_weights(checkpoint_filepath)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _main(args):
|
||||
global __file__
|
||||
|
||||
@ -161,10 +197,3 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
35
trainer.py
35
trainer.py
@ -1,12 +1,7 @@
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import soundfile as sf
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
|
||||
class Trainer():
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, model, optimizer, loss, strategy, path_experiment, args):
|
||||
self.model = model
|
||||
print(self.model.summary())
|
||||
@ -23,17 +18,20 @@ class Trainer():
|
||||
self.train_mae_s1 = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
|
||||
self.train_mae = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
|
||||
self.val_mae = tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
|
||||
self.val_loss = tf.keras.metrics.Mean(name='test_loss')
|
||||
|
||||
self.val_loss = tf.keras.metrics.Mean(name="test_loss")
|
||||
|
||||
def train_step(self, inputs):
|
||||
noisy, clean = inputs
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
logits_2, logits_1 = self.model(
|
||||
noisy, training=True
|
||||
) # Logits for this minibatch
|
||||
|
||||
logits_2,logits_1 = self.model(noisy, training=True) # Logits for this minibatch
|
||||
|
||||
loss_value = tf.reduce_mean(self.loss_object(clean, logits_2) + tf.reduce_mean(self.loss_object(clean, logits_1)))
|
||||
loss_value = tf.reduce_mean(
|
||||
self.loss_object(clean, logits_2)
|
||||
+ tf.reduce_mean(self.loss_object(clean, logits_1))
|
||||
)
|
||||
|
||||
grads = tape.gradient(loss_value, self.model.trainable_weights)
|
||||
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
|
||||
@ -42,11 +40,12 @@ class Trainer():
|
||||
return loss_value
|
||||
|
||||
def test_step(self, inputs):
|
||||
|
||||
noisy, clean = inputs
|
||||
|
||||
predictions_s2, predictions_s1 = self.model(noisy, training=False)
|
||||
t_loss = self.loss_object(clean, predictions_s2)+self.loss_object(clean, predictions_s1)
|
||||
t_loss = self.loss_object(clean, predictions_s2) + self.loss_object(
|
||||
clean, predictions_s1
|
||||
)
|
||||
|
||||
self.val_mae.update_state(clean, predictions_s2)
|
||||
self.val_loss.update_state(t_loss)
|
||||
@ -54,13 +53,11 @@ class Trainer():
|
||||
@tf.function()
|
||||
def distributed_training_step(self, inputs):
|
||||
per_replica_losses = self.strategy.run(self.train_step, args=(inputs,))
|
||||
reduced_losses=self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
|
||||
reduced_losses = self.strategy.reduce(
|
||||
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None
|
||||
)
|
||||
return reduced_losses
|
||||
|
||||
@tf.function
|
||||
def distributed_test_step(self, inputs):
|
||||
return self.strategy.run(self.test_step, args=(inputs,))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
266
unet.py
266
unet.py
@ -1,11 +1,11 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import Model, Input
|
||||
from tensorflow.keras import Input
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.initializers import TruncatedNormal
|
||||
import math as m
|
||||
|
||||
def build_model_denoise(unet_args=None):
|
||||
|
||||
def build_model_denoise(unet_args=None):
|
||||
inputs = Input(shape=(None, None, 2))
|
||||
|
||||
outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
|
||||
@ -14,17 +14,20 @@ def build_model_denoise(unet_args=None):
|
||||
model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class DenseBlock(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, N] => [B, T, F, N]
|
||||
DenseNet Block consisting of "num_layers" densely connected convolutional layers
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, num_layers, N, ksize, activation):
|
||||
'''
|
||||
"""
|
||||
num_layers: number of densely connected conv. layers
|
||||
N: Number of filters (same in each layer)
|
||||
ksize: Kernel size (same in each layer)
|
||||
'''
|
||||
"""
|
||||
super(DenseBlock, self).__init__()
|
||||
self.activation = activation
|
||||
|
||||
@ -33,52 +36,58 @@ class DenseBlock(layers.Layer):
|
||||
self.num_layers = num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
self.H.append(layers.Conv2D(filters=N,
|
||||
self.H.append(
|
||||
layers.Conv2D(
|
||||
filters=N,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation))
|
||||
padding="VALID",
|
||||
activation=self.activation,
|
||||
)
|
||||
)
|
||||
|
||||
def call(self, x):
|
||||
|
||||
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
||||
x_ = self.H[0](x_)
|
||||
if self.num_layers > 1:
|
||||
for h in self.H[1:]:
|
||||
x = tf.concat([x_, x], axis=-1)
|
||||
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
||||
x_ = h(x_)
|
||||
|
||||
return x_
|
||||
|
||||
|
||||
class FinalBlock(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, N] => [B, T, F, 2]
|
||||
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(FinalBlock, self).__init__()
|
||||
ksize = (3, 3)
|
||||
self.paddings_2 = get_paddings(ksize)
|
||||
self.conv2=layers.Conv2D(filters=2,
|
||||
self.conv2 = layers.Conv2D(
|
||||
filters=2,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
|
||||
padding="VALID",
|
||||
activation=None,
|
||||
)
|
||||
|
||||
def call(self, inputs):
|
||||
|
||||
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
|
||||
pred = self.conv2(x)
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
class SAM(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
|
||||
Supervised Attention Module:
|
||||
The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
|
||||
@ -86,48 +95,55 @@ class SAM(layers.Layer):
|
||||
The first stage output is then calculated adding the original input spectrogram to the residual noise.
|
||||
The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, n_feat):
|
||||
super(SAM, self).__init__()
|
||||
|
||||
ksize = (3, 3)
|
||||
self.paddings_1 = get_paddings(ksize)
|
||||
self.conv1 = layers.Conv2D(filters=n_feat,
|
||||
self.conv1 = layers.Conv2D(
|
||||
filters=n_feat,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
padding="VALID",
|
||||
activation=None,
|
||||
)
|
||||
ksize = (3, 3)
|
||||
self.paddings_2 = get_paddings(ksize)
|
||||
self.conv2=layers.Conv2D(filters=2,
|
||||
self.conv2 = layers.Conv2D(
|
||||
filters=2,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
padding="VALID",
|
||||
activation=None,
|
||||
)
|
||||
|
||||
ksize = (3, 3)
|
||||
self.paddings_3 = get_paddings(ksize)
|
||||
self.conv3 = layers.Conv2D(filters=n_feat,
|
||||
self.conv3 = layers.Conv2D(
|
||||
filters=n_feat,
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=None)
|
||||
padding="VALID",
|
||||
activation=None,
|
||||
)
|
||||
self.cropadd = CropAddBlock()
|
||||
|
||||
def call(self, inputs, input_spectrogram):
|
||||
x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
|
||||
x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
|
||||
x1 = self.conv1(x1)
|
||||
|
||||
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
|
||||
x = self.conv2(x)
|
||||
|
||||
# residual prediction
|
||||
pred = layers.Add()([x, input_spectrogram]) # features to next stage
|
||||
|
||||
x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
|
||||
x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
|
||||
M = self.conv3(x3)
|
||||
|
||||
M = tf.keras.activations.sigmoid(M)
|
||||
@ -138,17 +154,18 @@ class SAM(layers.Layer):
|
||||
|
||||
|
||||
class AddFreqEncoding(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, 2] => [B, T, F, 12]
|
||||
Generates frequency positional embeddings and concatenates them as 10 extra channels
|
||||
This function is optimized for F=1025
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, f_dim):
|
||||
super(AddFreqEncoding, self).__init__()
|
||||
pi = tf.constant(m.pi)
|
||||
pi=tf.cast(pi,'float32')
|
||||
pi = tf.cast(pi, "float32")
|
||||
self.f_dim = f_dim # f_dim is fixed
|
||||
n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
|
||||
n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")
|
||||
coss = tf.math.cos(pi * n)
|
||||
f_channel = tf.expand_dims(coss, -1) # (1025,1)
|
||||
self.fembeddings = f_channel
|
||||
@ -156,28 +173,38 @@ class AddFreqEncoding(layers.Layer):
|
||||
for k in range(1, 10):
|
||||
coss = tf.math.cos(2**k * pi * n)
|
||||
f_channel = tf.expand_dims(coss, -1) # (1025,1)
|
||||
self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
|
||||
|
||||
self.fembeddings = tf.concat(
|
||||
[self.fembeddings, f_channel], axis=-1
|
||||
) # (1025,10)
|
||||
|
||||
def call(self, input_tensor):
|
||||
|
||||
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
|
||||
time_dim = tf.shape(input_tensor)[1] # get time dimension
|
||||
|
||||
fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
|
||||
|
||||
fembeddings_2 = tf.broadcast_to(
|
||||
self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]
|
||||
)
|
||||
|
||||
return tf.concat([input_tensor, fembeddings_2], axis=-1) # (batch,427,1025,12)
|
||||
|
||||
|
||||
def get_paddings(K):
|
||||
return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]])
|
||||
return tf.constant(
|
||||
[
|
||||
[0, 0],
|
||||
[K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
|
||||
[K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
|
||||
[0, 0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Decoder(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, N] , skip connections => [B, T, F, N]
|
||||
Decoder side of the U-Net subnetwork.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, Ns, Ss, unet_args):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
@ -186,21 +213,30 @@ class Decoder(layers.Layer):
|
||||
self.activation = unet_args.activation
|
||||
self.depth = unet_args.depth
|
||||
|
||||
|
||||
ksize = (3, 3)
|
||||
self.paddings_3 = get_paddings(ksize)
|
||||
self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
|
||||
self.conv2d_3 = layers.Conv2D(
|
||||
filters=self.Ns[self.depth],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
padding="VALID",
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
self.cropadd = CropAddBlock()
|
||||
|
||||
self.dblocks = []
|
||||
for i in range(self.depth):
|
||||
self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc))
|
||||
self.dblocks.append(
|
||||
D_Block(
|
||||
layer_idx=i,
|
||||
N=self.Ns[i],
|
||||
S=self.Ss[i],
|
||||
activation=self.activation,
|
||||
num_tfc=unet_args.num_tfc,
|
||||
)
|
||||
)
|
||||
|
||||
def call(self, inputs, contracting_layers):
|
||||
x = inputs
|
||||
@ -208,12 +244,13 @@ class Decoder(layers.Layer):
|
||||
x = self.dblocks[i - 1](x, contracting_layers[i - 1])
|
||||
return x
|
||||
|
||||
class Encoder(tf.keras.Model):
|
||||
|
||||
'''
|
||||
class Encoder(tf.keras.Model):
|
||||
"""
|
||||
[B, T, F, N] => skip connections , [B, T, F, N_4]
|
||||
Encoder side of the U-Net subnetwork.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, Ns, Ss, unet_args):
|
||||
super(Encoder, self).__init__()
|
||||
self.Ns = Ns
|
||||
@ -225,14 +262,22 @@ class Encoder(tf.keras.Model):
|
||||
|
||||
self.eblocks = []
|
||||
for i in range(self.depth):
|
||||
self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
|
||||
self.eblocks.append(
|
||||
E_Block(
|
||||
layer_idx=i,
|
||||
N0=self.Ns[i],
|
||||
N=self.Ns[i + 1],
|
||||
S=self.Ss[i],
|
||||
activation=self.activation,
|
||||
num_tfc=unet_args.num_tfc,
|
||||
)
|
||||
)
|
||||
|
||||
self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)
|
||||
|
||||
def call(self, inputs):
|
||||
x = inputs
|
||||
for i in range(self.depth):
|
||||
|
||||
x, x_contract = self.eblocks[i](x)
|
||||
|
||||
self.contracting_layers[i] = x_contract # if remove 0, correct this
|
||||
@ -240,8 +285,8 @@ class Encoder(tf.keras.Model):
|
||||
|
||||
return x, self.contracting_layers
|
||||
|
||||
class MultiStage_denoise(tf.keras.Model):
|
||||
|
||||
class MultiStage_denoise(tf.keras.Model):
|
||||
def __init__(self, unet_args=None):
|
||||
super(MultiStage_denoise, self).__init__()
|
||||
|
||||
@ -259,13 +304,14 @@ class MultiStage_denoise(tf.keras.Model):
|
||||
# initial feature extractor
|
||||
ksize = (7, 7)
|
||||
self.paddings_1 = get_paddings(ksize)
|
||||
self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
|
||||
self.conv2d_1 = layers.Conv2D(
|
||||
filters=self.Ns[0],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
padding="VALID",
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
|
||||
self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
|
||||
@ -281,39 +327,41 @@ class MultiStage_denoise(tf.keras.Model):
|
||||
# initial feature extractor
|
||||
ksize = (7, 7)
|
||||
self.paddings_2 = get_paddings(ksize)
|
||||
self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
|
||||
self.conv2d_2 = layers.Conv2D(
|
||||
filters=self.Ns[0],
|
||||
kernel_size=ksize,
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
padding="VALID",
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
|
||||
self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)
|
||||
|
||||
@tf.function()
|
||||
def call(self, inputs):
|
||||
|
||||
if self.use_fencoding:
|
||||
x_w_freq = self.freq_encoding(inputs) # None, None, 1025, 12
|
||||
else:
|
||||
x_w_freq = inputs
|
||||
|
||||
# intitial feature extractor
|
||||
x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
|
||||
x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
|
||||
x = self.conv2d_1(x) # None, None, 1025, 32
|
||||
|
||||
x, contracting_layers_s1 = self.encoder_s1(x)
|
||||
# decoder
|
||||
feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
|
||||
feats_s1 = self.decoder_s1(
|
||||
x, contracting_layers_s1
|
||||
) # None, None, 1025, 32 features
|
||||
|
||||
if self.num_stages > 1:
|
||||
# SAM module
|
||||
Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
|
||||
|
||||
# intitial feature extractor
|
||||
x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
|
||||
x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
|
||||
x = self.conv2d_2(x)
|
||||
|
||||
if self.use_sam:
|
||||
@ -323,7 +371,9 @@ class MultiStage_denoise(tf.keras.Model):
|
||||
|
||||
x, contracting_layers_s2 = self.encoder_s2(x)
|
||||
|
||||
feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
|
||||
feats_s2 = self.decoder_s2(
|
||||
x, contracting_layers_s2
|
||||
) # None, None, 1025, 32 features
|
||||
|
||||
# consider implementing a third stage?
|
||||
|
||||
@ -333,23 +383,27 @@ class MultiStage_denoise(tf.keras.Model):
|
||||
pred_stage_1 = self.finalblock(feats_s1)
|
||||
return pred_stage_1
|
||||
|
||||
|
||||
class I_Block(layers.Layer):
|
||||
'''
|
||||
"""
|
||||
[B, T, F, N] => [B, T, F, N]
|
||||
Intermediate block:
|
||||
Basically, a densenet block with a residual connection
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, N, activation, num_tfc, **kwargs):
|
||||
super(I_Block, self).__init__(**kwargs)
|
||||
|
||||
ksize = (3, 3)
|
||||
self.tfc = DenseBlock(num_tfc, N, ksize, activation)
|
||||
|
||||
self.conv2d_res= layers.Conv2D(filters=N,
|
||||
self.conv2d_res = layers.Conv2D(
|
||||
filters=N,
|
||||
kernel_size=(1, 1),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
padding='VALID')
|
||||
padding="VALID",
|
||||
)
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.tfc(inputs)
|
||||
@ -359,7 +413,6 @@ class I_Block(layers.Layer):
|
||||
|
||||
|
||||
class E_Block(layers.Layer):
|
||||
|
||||
def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
|
||||
super(E_Block, self).__init__(**kwargs)
|
||||
self.layer_idx = layer_idx
|
||||
@ -371,31 +424,33 @@ class E_Block(layers.Layer):
|
||||
|
||||
ksize = (S[0] + 2, S[1] + 2)
|
||||
self.paddings_2 = get_paddings(ksize)
|
||||
self.conv2d_2 = layers.Conv2D(filters=N,
|
||||
self.conv2d_2 = layers.Conv2D(
|
||||
filters=N,
|
||||
kernel_size=(S[0] + 2, S[1] + 2),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=S,
|
||||
padding='VALID',
|
||||
activation=self.activation)
|
||||
|
||||
padding="VALID",
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
def call(self, inputs, training=None, **kwargs):
|
||||
x = self.i_block(inputs)
|
||||
|
||||
x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
|
||||
x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
|
||||
x_down = self.conv2d_2(x_down)
|
||||
|
||||
return x_down, x
|
||||
|
||||
|
||||
def get_config(self):
|
||||
return dict(layer_idx=self.layer_idx,
|
||||
return dict(
|
||||
layer_idx=self.layer_idx,
|
||||
N=self.N,
|
||||
S=self.S,
|
||||
**super(E_Block, self).get_config()
|
||||
**super(E_Block, self).get_config(),
|
||||
)
|
||||
class D_Block(layers.Layer):
|
||||
|
||||
|
||||
class D_Block(layers.Layer):
|
||||
def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
|
||||
super(D_Block, self).__init__(**kwargs)
|
||||
self.layer_idx = layer_idx
|
||||
@ -405,29 +460,35 @@ class D_Block(layers.Layer):
|
||||
ksize = (S[0] + 2, S[1] + 2)
|
||||
self.paddings_1 = get_paddings(ksize)
|
||||
|
||||
self.tconv_1= layers.Conv2DTranspose(filters=N,
|
||||
self.tconv_1 = layers.Conv2DTranspose(
|
||||
filters=N,
|
||||
kernel_size=(S[0] + 2, S[1] + 2),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=S,
|
||||
activation=self.activation,
|
||||
padding='VALID')
|
||||
padding="VALID",
|
||||
)
|
||||
|
||||
self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
|
||||
self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")
|
||||
|
||||
self.projection = layers.Conv2D(filters=N,
|
||||
self.projection = layers.Conv2D(
|
||||
filters=N,
|
||||
kernel_size=(1, 1),
|
||||
kernel_initializer=TruncatedNormal(),
|
||||
strides=1,
|
||||
activation=self.activation,
|
||||
padding='VALID')
|
||||
padding="VALID",
|
||||
)
|
||||
self.cropadd = CropAddBlock()
|
||||
self.cropconcat = CropConcatBlock()
|
||||
|
||||
self.i_block = I_Block(N, activation, num_tfc)
|
||||
|
||||
def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs):
|
||||
def call(
|
||||
self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
|
||||
):
|
||||
x = inputs
|
||||
x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||
x = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
||||
x = self.tconv_1(inputs)
|
||||
|
||||
x2 = self.upsampling(inputs)
|
||||
@ -437,39 +498,40 @@ class D_Block(layers.Layer):
|
||||
|
||||
x = self.cropadd(x, x2)
|
||||
|
||||
|
||||
x = self.cropconcat(x, bridge)
|
||||
|
||||
x = self.i_block(x)
|
||||
return x
|
||||
|
||||
def get_config(self):
|
||||
return dict(layer_idx=self.layer_idx,
|
||||
return dict(
|
||||
layer_idx=self.layer_idx,
|
||||
N=self.N,
|
||||
S=self.S,
|
||||
**super(D_Block, self).get_config()
|
||||
**super(D_Block, self).get_config(),
|
||||
)
|
||||
|
||||
class CropAddBlock(layers.Layer):
|
||||
|
||||
class CropAddBlock(layers.Layer):
|
||||
def call(self, down_layer, x, **kwargs):
|
||||
x1_shape = tf.shape(down_layer)
|
||||
x2_shape = tf.shape(x)
|
||||
|
||||
|
||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||
|
||||
down_layer_cropped = down_layer[:,
|
||||
down_layer_cropped = down_layer[
|
||||
:,
|
||||
height_diff : (x2_shape[1] + height_diff),
|
||||
width_diff : (x2_shape[2] + width_diff),
|
||||
:]
|
||||
:,
|
||||
]
|
||||
|
||||
x = layers.Add()([down_layer_cropped, x])
|
||||
return x
|
||||
|
||||
class CropConcatBlock(layers.Layer):
|
||||
|
||||
class CropConcatBlock(layers.Layer):
|
||||
def call(self, down_layer, x, **kwargs):
|
||||
x1_shape = tf.shape(down_layer)
|
||||
x2_shape = tf.shape(x)
|
||||
@ -477,10 +539,12 @@ class CropConcatBlock(layers.Layer):
|
||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||
|
||||
down_layer_cropped = down_layer[:,
|
||||
down_layer_cropped = down_layer[
|
||||
:,
|
||||
height_diff : (x2_shape[1] + height_diff),
|
||||
width_diff : (x2_shape[2] + width_diff),
|
||||
:]
|
||||
:,
|
||||
]
|
||||
|
||||
x = tf.concat([down_layer_cropped, x], axis=-1)
|
||||
return x
|
||||
|
||||
Loading…
Reference in New Issue
Block a user