fix formatting & linters

This commit is contained in:
festinuz 2025-05-08 12:56:32 +03:00
parent 114fce7c84
commit 9870240572
6 changed files with 1143 additions and 904 deletions

.gitignore (new file)

@ -0,0 +1,3 @@
experiments
outputs
__pycache__

dataset_loader.py

@ -4,14 +4,12 @@ import tensorflow as tf
import random
import os
import numpy as np
from scipy.fft import fft, ifft
import soundfile as sf
import math
import pandas as pd
import scipy as sp
import glob
from tqdm import tqdm
# Generator function. It reads the CSV file with pandas and loads the largest audio segments from each recording. If extend=False, it only reads the segments with length > length_seq, trims them and yields them with no further processing. Otherwise, if a segment is shorter than length_seq, it extends it using concatenative synthesis.
def __noise_sample_generator(info_file, fs, length_seq, split):
head = os.path.split(info_file)[0]
@ -26,14 +24,25 @@ def __noise_sample_generator(info_file,fs, length_seq, split):
for i in r:
segments = ast.literal_eval(load_data_split.loc[i, "segments"])
if split == "test":
loaded_data, Fs = sf.read(
os.path.join(
head,
load_data_split["recording"].loc[i],
load_data_split["largest_segment"].loc[i],
)
)
else:
num = np.random.randint(0, len(segments))
loaded_data, Fs = sf.read(
os.path.join(
head, load_data_split["recording"].loc[i], segments[num]
)
)
assert fs == Fs, "wrong sampling rate"
yield __extend_sample_by_repeating(loaded_data, fs, length_seq)
def __extend_sample_by_repeating(data, fs, seq_len):
rpm = 78
target_samp = seq_len
@ -57,12 +66,11 @@ def __extend_sample_by_repeating(data, fs,seq_len):
overhead = len(data) % period_sam
if overhead > bls:
complete_periods = (len(data) // period_sam) * period_sam
else:
complete_periods = (len(data) // period_sam - 1) * period_sam
a = np.multiply(data[0:bls], window_left)
b = np.multiply(data[complete_periods : complete_periods + bls], window_right)
c_1 = np.concatenate((data[0:complete_periods, :], b))
@ -71,10 +79,9 @@ def __extend_sample_by_repeating(data, fs,seq_len):
large_data[0 : complete_periods + bls, :] = c_1
pointer = complete_periods
not_finished = True
while not_finished:
if target_samp > pointer + complete_periods + bls:
large_data[pointer : pointer + complete_periods + bls] += c_2
pointer += complete_periods
@ -86,8 +93,9 @@ def __extend_sample_by_repeating(data, fs,seq_len):
return large_data
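A minimal sketch of the crossfade-looping idea above: repeat a short mono sample up to a target length, blending each seam with half-Hann fade-out/fade-in windows (assumes len(x) > fade_len). extend_by_looping and its defaults are illustrative names, not code from this repository.

import numpy as np

def extend_by_looping(x, target_len, fade_len=256):
    # half-Hann fades for the seams
    w = np.hanning(2 * fade_len)
    w_left, w_right = w[:fade_len], w[fade_len:]
    out = x.astype(np.float64).copy()
    while len(out) < target_len:
        # fade out the current tail, fade in the head of the repeated copy
        out[-fade_len:] = out[-fade_len:] * w_right + x[:fade_len] * w_left
        out = np.concatenate((out, x[fade_len:]))
    return out[:target_len]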
def generate_real_recordings_data(
path_recordings, fs=44100, seg_len_s=15, stereo=False
):
records_info = os.path.join(path_recordings, "audio_files.txt")
num_lines = sum(1 for line in open(records_info))
f = open(records_info, "r")
@ -112,8 +120,18 @@ def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stere
return records
def generate_paired_data_test_formal(
path_pianos,
path_noises,
noise_amount="low_snr",
num_samples=-1,
fs=44100,
seg_len_s=5,
extend=True,
stereo=False,
prenoise=False,
):
print(num_samples)
segments_clean = []
segments_noisy = []
@ -134,9 +152,13 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
train_samples = sorted(train_samples)
if prenoise:
noise_generator = __noise_sample_generator(
noises_info, fs, seg_len + fs, extend, "test"
) # adds 1 s of silence at the beginning, so a longer noise segment is needed
else:
noise_generator = __noise_sample_generator(
noises_info, fs, seg_len, extend, "test"
) # this will take care of everything
# load data clean files
for file in tqdm(train_samples): # add [1:5] for testing
data_clean, samplerate = sf.read(file)
@ -159,19 +181,22 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
print(num_frames)
if num_frames == 0:
data_clean = np.concatenate(
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
axis=0,
)
num_frames = 1
data_not_finished = True
pointer = 0
while data_not_finished:
if i >= num_samples:
break
segment = data_clean[pointer : pointer + seg_len]
pointer = pointer + hop_size
if pointer + seg_len > len(data_clean):
data_not_finished = False
segment = segment.astype("float32")
# SNRs=np.random.uniform(2,20)
snr = SNRs[i]
@ -195,23 +220,26 @@ def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low
# sum both signals according to snr
if prenoise:
segment = np.concatenate(
(np.zeros(shape=(fs,)), segment), axis=0
) # add one second of silence
summed = (
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
) # not sure if this is correct, maybe revisit later!!
summed = summed.astype("float32")
# yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
summed = 10.0 ** (scale / 10.0) * summed
segment = 10.0 ** (scale / 10.0) * segment
segments_noisy.append(summed.astype("float32"))
segments_clean.append(segment.astype("float32"))
i = i + 1
return segments_noisy, segments_clean
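On the recurring "not sure if this is correct" note: scaling the noise by sqrt(power_clean / (snr * power_noise)) does produce the requested SNR when SNR is defined as a ratio of variances. A quick self-contained check, with illustrative values:

import numpy as np

rng = np.random.default_rng(0)
clean = rng.normal(size=44100)
noise = rng.normal(scale=3.0, size=44100)
snr_lin = 10.0 ** (6.0 / 10.0)  # 6 dB target, same dB-to-linear mapping as above
g = np.sqrt(np.var(clean) / (snr_lin * np.var(noise)))
mixed = clean + g * noise
print(10.0 * np.log10(np.var(clean) / np.var(g * noise)))  # ~6.0 dB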
def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5):
segments_clean = []
segments_noisy = []
seg_len = fs * seg_len_s
@ -222,9 +250,10 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
train_samples = glob.glob(os.path.join(path, "*.wav"))
train_samples = sorted(train_samples)
noise_generator = __noise_sample_generator(
noises_info, fs, seg_len, "test"
) # this will take care of everything
# load data clean files
for file in tqdm(train_samples): # add [1:5] for testing
data_clean, samplerate = sf.read(file)
if samplerate != fs:
@ -243,13 +272,18 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
if num_frames == 0:
data_clean = np.concatenate(
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
axis=0,
)
num_frames = 1
pointer = 0
segment = data_clean[pointer : pointer + (seg_len - 2 * fs)]
segment = segment.astype("float32")
segment = np.concatenate(
(np.zeros(shape=(2 * fs,)), segment), axis=0
) # I hope it's ok
# segments_clean.append(segment)
for snr in SNRs:
@ -269,17 +303,21 @@ def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len
snr = 10.0 ** (snr / 10.0)
# sum both signals according to snr
summed = (
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
) # not sure if this is correct, maybe revisit later!!
summed = summed.astype("float32")
# yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
segments_noisy.append(summed.astype("float32"))
segments_clean.append(segment.astype("float32"))
return segments_noisy, segments_clean
def generate_val_data(
path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5
):
val_samples = []
for path in path_music:
val_samples.extend(glob.glob(os.path.join(path, "*.wav")))
@ -304,7 +342,6 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
seg_len = fs * seg_len_s
segments_clean = []
for file in tqdm(data_clean_loaded):
# framify arguments: seg_len, hop_size
hop_size = int(seg_len) # no overlap
@ -313,7 +350,7 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
for i in range(0, int(num_frames)):
segment = file[pointer : pointer + seg_len]
pointer = pointer + hop_size
segment=segment.astype('float32')
segment = segment.astype("float32")
segments_clean.append(segment)
del data_clean_loaded
@ -323,8 +360,9 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
# noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
noises_info = os.path.join(path_noises, "info.csv")
noise_generator = __noise_sample_generator(
noises_info, fs, seg_len, split
) # this will take care of everything
# generate noisy segments
# load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
@ -345,7 +383,6 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
data_clean = segments_clean[i]
# configure sizes
# estimate clean signal power
power_clean = np.var(data_clean)
# estimate noise power
@ -354,33 +391,37 @@ def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, se
snr = 10.0 ** (SNRs[i] / 10.0)
# sum both signals according to snr
summed = (
data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise
) # not sure if this is correct, maybe revisit later!!
# the rest is normal
summed = 10.0 ** (scales[i] / 10.0) * summed
segments_clean[i] = 10.0 ** (scales[i] / 10.0) * segments_clean[i]
segments_noisy.append(summed.astype("float32"))
return segments_noisy, segments_clean
def generator_train(
path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False
):
train_samples = []
for path in path_music:
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"), "*.wav")))
seg_len = fs * seg_len_s
noises_info = os.path.join(path_noises.decode("utf-8"), "info.csv")
noise_generator = __noise_sample_generator(
noises_info, fs, seg_len, split.decode("utf-8")
) # this will take care of everything
# load data clean files
while True:
random.shuffle(train_samples)
for file in train_samples:
data, samplerate = sf.read(file)
assert samplerate == fs, "wrong sampling rate"
data_clean = data
# Stereo to mono
if len(data.shape) > 1:
@ -396,27 +437,33 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
num_frames = np.floor(len(data_clean) / seg_len)
if num_frames == 0:
data_clean = np.concatenate(
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
axis=0,
)
num_frames = 1
pointer = 0
data_clean = np.roll(
data_clean, np.random.randint(0, seg_len)
) # if only one frame, roll it for augmentation
elif num_frames > 1:
pointer = np.random.randint(
0, hop_size
) # initial shifting; great for augmentation, better than overlap as we get different frames at each "while" iteration
else:
pointer = 0
data_not_finished = True
while data_not_finished:
segment = data_clean[pointer : pointer + seg_len]
pointer = pointer + hop_size
if pointer + seg_len > len(data_clean):
data_not_finished = False
segment = segment.astype("float32")
SNRs = np.random.uniform(2, 20)
scale = np.random.uniform(-6, 4)
# load noise signal
data_noise = next(noise_generator)
data_noise = np.mean(data_noise, axis=1)
@ -427,9 +474,13 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
# configure sizes
if stereo:
# estimate clean signal power
power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var(
segment[:, 1]
)
# estimate noise power
power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var(
new_noise[:, 1]
)
else:
# estimate clean signal power
power_clean = np.var(segment)
@ -438,39 +489,63 @@ def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend
snr = 10.0 ** (SNRs / 10.0)
# sum both signals according to snr
summed = (
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
) # not sure if this is correct, maybe revisit later!!
summed = 10.0 ** (scale / 10.0) * summed
segment = 10.0 ** (scale / 10.0) * segment
summed = summed.astype("float32")
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
def load_data(
buffer_size,
path_music_train,
path_music_val,
path_noises,
fs=44100,
seg_len_s=5,
extend=True,
stereo=False,
):
print("Generating train dataset")
trainshape = int(fs * seg_len_s)
dataset_train = tf.data.Dataset.from_generator(
generator_train,
args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo),
output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))),
output_types=(tf.float32, tf.float32),
)
print("Generating validation dataset")
segments_noisy, segments_clean = generate_val_data(
path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s
)
dataset_val = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
return dataset_train.shuffle(buffer_size), dataset_val
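A toy version of the from_generator pattern used in load_data, to make the expected shapes explicit; recent TensorFlow releases prefer a single output_signature argument over the output_types/output_shapes pair, but both forms work here. A sketch, not the project's code:

import tensorflow as tf

def toy_gen(n):
    for i in range(int(n)):
        x = tf.fill((4,), float(i))
        yield x, x  # stands in for the (noisy, clean) pairs of generator_train

ds = tf.data.Dataset.from_generator(
    toy_gen,
    args=(3,),
    output_shapes=(tf.TensorShape((4,)), tf.TensorShape((4,))),
    output_types=(tf.float32, tf.float32),
)
for noisy, clean in ds:
    print(noisy.shape, clean.shape)  # (4,) (4,)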
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
print("Generating test dataset")
segments_noisy, segments_clean = generate_test_data(
path_pianos_test, path_noises, extend=True, **kwargs
)
dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
# dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
# train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
return dataset_test
def load_data_formal(path_pianos_test, path_noises, **kwargs):
print("Generating test dataset")
segments_noisy, segments_clean = generate_paired_data_test_formal(
path_pianos_test, path_noises, extend=True, **kwargs
)
print("segments::")
print(len(segments_noisy))
dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
@ -478,6 +553,7 @@ def load_data_formal( path_pianos_test, path_noises, **kwargs) :
# train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
return dataset_test
def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
print("Generating real test dataset")


@ -4,6 +4,7 @@ import logging
logger = logging.getLogger(__name__)
def run(args):
import unet
import tensorflow as tf
@ -16,33 +17,45 @@ def run(args):
unet_model = unet.build_model_denoise(unet_args=args.unet)
ckpt = os.path.join(
os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
)
unet_model.load_weights(ckpt)
def do_stft(noisy):
window_fn = tf.signal.hamming_window
win_size = args.stft.win_size
hop_size = args.stft.hop_size
stft_signal_noisy = tf.signal.stft(
noisy,
frame_length=win_size,
window_fn=window_fn,
frame_step=hop_size,
pad_end=True,
)
stft_noisy_stacked = tf.stack(
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
axis=-1,
)
return stft_noisy_stacked
def do_istft(data):
window_fn = tf.signal.hamming_window
win_size = args.stft.win_size
hop_size = args.stft.hop_size
inv_window_fn = tf.signal.inverse_stft_window_fn(
hop_size, forward_window_fn=window_fn
)
pred_cpx = data[..., 0] + 1j * data[..., 1]
pred_time = tf.signal.inverse_stft(
pred_cpx, win_size, hop_size, window_fn=inv_window_fn
)
return pred_time
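A self-contained round trip of the do_stft/do_istft pair defined above; win_size and hop_size are illustrative values, not the ones read from args.stft:

import tensorflow as tf

win_size, hop_size = 1024, 256
x = tf.random.normal((1, 44100))
S = tf.signal.stft(
    x,
    frame_length=win_size,
    frame_step=hop_size,
    window_fn=tf.signal.hamming_window,
    pad_end=True,
)
stacked = tf.stack([tf.math.real(S), tf.math.imag(S)], axis=-1)  # as in do_stft
cpx = tf.complex(stacked[..., 0], stacked[..., 1])
inv_win = tf.signal.inverse_stft_window_fn(
    hop_size, forward_window_fn=tf.signal.hamming_window
)
y = tf.signal.inverse_stft(cpx, win_size, hop_size, window_fn=inv_win)
print(y.shape)  # padded length; y[:, :44100] approximates x away from the edges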
audio = str(args.inference.audio)
@ -57,8 +70,6 @@ def run(args):
data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1)
segment_size = 44100 * 5 # 5 s segments
length_data = len(data)
@ -66,7 +77,6 @@ def run(args):
window = np.hanning(2 * overlapsize)
window_right = window[overlapsize::]
window_left = window[0:overlapsize]
pointer = 0
denoised_data = np.zeros(shape=(len(data),))
residual_noise = np.zeros(shape=(len(data),))
@ -87,21 +97,72 @@ def run(args):
residual_time = np.array(residual_time)
if pointer == 0:
pred_time = np.concatenate(
(
pred_time[0 : int(segment_size - overlapsize)],
np.multiply(
pred_time[int(segment_size - overlapsize) : segment_size],
window_right,
),
),
axis=0,
)
residual_time = np.concatenate(
(
residual_time[0 : int(segment_size - overlapsize)],
np.multiply(
residual_time[
int(segment_size - overlapsize) : segment_size
],
window_right,
),
),
axis=0,
)
else:
pred_time = np.concatenate(
(
np.multiply(pred_time[0 : int(overlapsize)], window_left),
pred_time[int(overlapsize) : int(segment_size - overlapsize)],
np.multiply(
pred_time[
int(segment_size - overlapsize) : int(segment_size)
],
window_right,
),
),
axis=0,
)
residual_time = np.concatenate(
(
np.multiply(residual_time[0 : int(overlapsize)], window_left),
residual_time[
int(overlapsize) : int(segment_size - overlapsize)
],
np.multiply(
residual_time[
int(segment_size - overlapsize) : int(segment_size)
],
window_right,
),
),
axis=0,
)
denoised_data[pointer : pointer + segment_size] = (
denoised_data[pointer : pointer + segment_size] + pred_time
)
residual_noise[pointer : pointer + segment_size] = (
residual_noise[pointer : pointer + segment_size] + residual_time
)
pointer = pointer + segment_size - overlapsize
else:
segment = data[pointer::]
lensegment = len(segment)
segment = np.concatenate(
(segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0
)
# dostft
segment_TF = do_stft(segment)
@ -121,11 +182,27 @@ def run(args):
pred_time = pred_time
residual_time = residual_time
else:
pred_time = np.concatenate(
(
np.multiply(pred_time[0 : int(overlapsize)], window_left),
pred_time[int(overlapsize) : int(segment_size)],
),
axis=0,
)
residual_time = np.concatenate(
(
np.multiply(residual_time[0 : int(overlapsize)], window_left),
residual_time[int(overlapsize) : int(segment_size)],
),
axis=0,
)
denoised_data[pointer::] = (
denoised_data[pointer::] + pred_time[0:lensegment]
)
residual_noise[pointer::] = (
residual_noise[pointer::] + residual_time[0:lensegment]
)
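The overlap-add above relies on the window_right/window_left fade pair summing to roughly one across each seam; for np.hanning halves the sum is close to, but not exactly, one, which a short check makes visible:

import numpy as np

overlapsize = 1024
w = np.hanning(2 * overlapsize)
print(np.max(np.abs(w[:overlapsize] + w[overlapsize:] - 1.0)))  # small, nonzero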
basename = os.path.splitext(audio)[0]
wav_noisy_name = basename + "_noisy_input" + ".wav"
@ -156,10 +233,3 @@ def main(args):
if __name__ == "__main__":
main()

train.py

@ -4,15 +4,14 @@ import logging
logger = logging.getLogger(__name__)
def run(args):
import unet
import tensorflow as tf
import dataset_loader
from tensorflow.keras.optimizers import Adam
import soundfile as sf
import datetime
from tqdm import tqdm
import numpy as np
path_experiment = str(args.path_experiment)
@ -20,55 +19,70 @@ def run(args):
os.makedirs(path_experiment)
path_music_train = args.dset.path_music_train
path_music_validation = args.dset.path_music_validation
path_noise = args.dset.path_noise
fs = args.fs
seg_len_s_train = args.seg_len_s_train
batch_size = args.batch_size
epochs = args.epochs
buffer_size = args.buffer_size # for shuffle
tensorboard_logs = args.tensorboard_logs
def do_stft(noisy, clean=None):
window_fn = tf.signal.hamming_window
win_size = args.stft.win_size
hop_size = args.stft.hop_size
stft_signal_noisy = tf.signal.stft(
noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
)
stft_noisy_stacked = tf.stack(
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
axis=-1,
)
if clean is not None:
stft_signal_clean = tf.signal.stft(
clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
)
stft_clean_stacked = tf.stack(
values=[
tf.math.real(stft_signal_clean),
tf.math.imag(stft_signal_clean),
],
axis=-1,
)
return stft_noisy_stacked, stft_clean_stacked
else:
return stft_noisy_stacked
# Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
dataset_train, dataset_val = dataset_loader.load_data(
buffer_size,
path_music_train,
path_music_validation,
path_noise,
fs=fs,
seg_len_s=seg_len_s_train,
)
dataset_train = dataset_train.map(
do_stft, num_parallel_calls=args.num_workers, deterministic=None
)
dataset_val = dataset_val.map(
do_stft, num_parallel_calls=args.num_workers, deterministic=None
)
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
with strategy.scope():
# build the model
@ -80,12 +94,17 @@ def run(args):
loss = tf.keras.losses.MeanAbsoluteError()
if args.use_tensorboard:
log_dir = os.path.join(
tensorboard_logs,
os.path.basename(path_experiment)
+ "_"
+ datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
)
train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")
# path where the checkpoints will be saved
checkpoint_filepath = os.path.join(path_experiment, "checkpoint")
dataset_train = dataset_train.batch(batch_size)
dataset_val = dataset_val.batch(batch_size)
@ -106,27 +125,43 @@ def run(args):
for epoch in range(epochs):
total_loss = 0
step_loss = 0
for step in tqdm(
range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
):
step_loss = trainer.distributed_training_step(iterator.get_next())
total_loss += step_loss
with train_summary_writer.as_default():
tf.summary.scalar("batch_loss", step_loss, step=step)
tf.summary.scalar(
"batch_mean_absolute_error", trainer.train_mae.result(), step=step
)
train_loss = total_loss / args.steps_per_epoch
for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
trainer.distributed_test_step(x)
template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
print(
template.format(
epoch + 1,
train_loss,
trainer.train_mae.result(),
trainer.val_loss.result(),
trainer.val_mae.result(),
)
)
with train_summary_writer.as_default():
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
tf.summary.scalar(
"epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
)
with val_summary_writer.as_default():
tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
tf.summary.scalar(
"epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
)
trainer.train_mae.reset_states()
trainer.val_loss.reset_states()
@ -137,10 +172,11 @@ def run(args):
current_lr *= 1e-1
trainer.optimizer.lr = current_lr
try:
unet_model.save_weights(checkpoint_filepath)
except Exception:
pass
def _main(args):
global __file__
@ -161,10 +197,3 @@ def main(args):
if __name__ == "__main__":
main()


@ -1,12 +1,7 @@
import os
import numpy as np
import tensorflow as tf
import soundfile as sf
from tqdm import tqdm
import pandas as pd
class Trainer:
def __init__(self, model, optimizer, loss, strategy, path_experiment, args):
self.model = model
print(self.model.summary())
@ -23,17 +18,20 @@ class Trainer():
self.train_mae_s1 = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
self.train_mae = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
self.val_mae = tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
self.val_loss = tf.keras.metrics.Mean(name="test_loss")
def train_step(self, inputs):
noisy, clean = inputs
with tf.GradientTape() as tape:
logits_2, logits_1 = self.model(
noisy, training=True
) # Logits for this minibatch
loss_value = tf.reduce_mean(
self.loss_object(clean, logits_2)
+ tf.reduce_mean(self.loss_object(clean, logits_1))
)
grads = tape.gradient(loss_value, self.model.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
@ -42,11 +40,12 @@ class Trainer():
return loss_value
def test_step(self, inputs):
noisy, clean = inputs
predictions_s2, predictions_s1 = self.model(noisy, training=False)
t_loss = self.loss_object(clean, predictions_s2) + self.loss_object(
clean, predictions_s1
)
self.val_mae.update_state(clean, predictions_s2)
self.val_loss.update_state(t_loss)
@ -54,13 +53,11 @@ class Trainer():
@tf.function()
def distributed_training_step(self, inputs):
per_replica_losses = self.strategy.run(self.train_step, args=(inputs,))
reduced_losses = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None
)
return reduced_losses
@tf.function
def distributed_test_step(self, inputs):
return self.strategy.run(self.test_step, args=(inputs,))

unet.py

@ -1,11 +1,11 @@
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m
def build_model_denoise(unet_args=None):
inputs = Input(shape=(None, None, 2))
outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
@ -14,17 +14,20 @@ def build_model_denoise(unet_args=None):
model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
return model
class DenseBlock(layers.Layer):
"""
[B, T, F, N] => [B, T, F, N]
DenseNet Block consisting of "num_layers" densely connected convolutional layers
"""
def __init__(self, num_layers, N, ksize, activation):
"""
num_layers: number of densely connected conv. layers
N: Number of filters (same in each layer)
ksize: Kernel size (same in each layer)
"""
super(DenseBlock, self).__init__()
self.activation = activation
@ -33,52 +36,58 @@ class DenseBlock(layers.Layer):
self.num_layers = num_layers
for i in range(num_layers):
self.H.append(
layers.Conv2D(
filters=N,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=self.activation,
)
)
def call(self, x):
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
x_ = self.H[0](x_)
if self.num_layers > 1:
for h in self.H[1:]:
x = tf.concat([x_, x], axis=-1)
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
x_ = h(x_)
return x_
class FinalBlock(layers.Layer):
"""
[B, T, F, N] => [B, T, F, 2]
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
"""
def __init__(self):
super(FinalBlock, self).__init__()
ksize = (3, 3)
self.paddings_2 = get_paddings(ksize)
self.conv2 = layers.Conv2D(
filters=2,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=None,
)
def call(self, inputs):
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
pred = self.conv2(x)
return pred
class SAM(layers.Layer):
"""
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
Supervised Attention Module:
The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
@ -86,48 +95,55 @@ class SAM(layers.Layer):
The first stage output is then calculated adding the original input spectrogram to the residual noise.
The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.
"""
def __init__(self, n_feat):
super(SAM, self).__init__()
ksize = (3, 3)
self.paddings_1 = get_paddings(ksize)
self.conv1 = layers.Conv2D(
filters=n_feat,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=None,
)
ksize = (3, 3)
self.paddings_2 = get_paddings(ksize)
self.conv2 = layers.Conv2D(
filters=2,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=None,
)
ksize = (3, 3)
self.paddings_3 = get_paddings(ksize)
self.conv3 = layers.Conv2D(
filters=n_feat,
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=None,
)
self.cropadd = CropAddBlock()
def call(self, inputs, input_spectrogram):
x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
x1 = self.conv1(x1)
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
x = self.conv2(x)
# residual prediction
pred = layers.Add()([x, input_spectrogram]) # features to next stage
x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
M = self.conv3(x3)
M = tf.keras.activations.sigmoid(M)
@ -138,17 +154,18 @@ class SAM(layers.Layer):
class AddFreqEncoding(layers.Layer):
"""
[B, T, F, 2] => [B, T, F, 12]
Generates frequency positional embeddings and concatenates them as 10 extra channels
This function is optimized for F=1025
"""
def __init__(self, f_dim):
super(AddFreqEncoding, self).__init__()
pi = tf.constant(m.pi)
pi = tf.cast(pi, "float32")
self.f_dim = f_dim # f_dim is fixed
n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")
coss = tf.math.cos(pi * n)
f_channel = tf.expand_dims(coss, -1) # (1025,1)
self.fembeddings = f_channel
@ -156,28 +173,38 @@ class AddFreqEncoding(layers.Layer):
for k in range(1, 10):
coss = tf.math.cos(2**k * pi * n)
f_channel = tf.expand_dims(coss, -1) # (1025,1)
self.fembeddings = tf.concat(
[self.fembeddings, f_channel], axis=-1
) # (1025,10)
def call(self, input_tensor):
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
time_dim = tf.shape(input_tensor)[1] # get time dimension
fembeddings_2 = tf.broadcast_to(
self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]
)
return tf.concat([input_tensor, fembeddings_2], axis=-1) # (batch,427,1025,12)
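A NumPy sketch of what AddFreqEncoding precomputes: channel k holds cos(2**k * pi * n) over the normalized frequency axis, yielding the 10 extra channels broadcast and concatenated above:

import numpy as np

f_dim = 1025
n = np.arange(f_dim) / (f_dim - 1)
emb = np.stack([np.cos(2.0**k * np.pi * n) for k in range(10)], axis=-1)
print(emb.shape)  # (1025, 10); broadcast to (batch, time, 1025, 10) in call()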
def get_paddings(K):
return tf.constant(
[
[0, 0],
[K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
[K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
[0, 0],
]
)
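Worked example for get_paddings: with these amounts, tf.pad(..., mode="SYMMETRIC") followed by a VALID convolution preserves the spatial size, i.e. a SAME-style convolution with symmetric rather than zero padding; even kernels pad one element less on the right/bottom:

print(get_paddings((3, 3)))  # [[0 0] [1 1] [1 1] [0 0]]
print(get_paddings((4, 4)))  # [[0 0] [2 1] [2 1] [0 0]]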
class Decoder(layers.Layer):
"""
[B, T, F, N] , skip connections => [B, T, F, N]
Decoder side of the U-Net subnetwork.
"""
def __init__(self, Ns, Ss, unet_args):
super(Decoder, self).__init__()
@ -186,21 +213,30 @@ class Decoder(layers.Layer):
self.activation = unet_args.activation
self.depth = unet_args.depth
ksize = (3, 3)
self.paddings_3 = get_paddings(ksize)
self.conv2d_3 = layers.Conv2D(
filters=self.Ns[self.depth],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=self.activation,
)
self.cropadd = CropAddBlock()
self.dblocks = []
for i in range(self.depth):
self.dblocks.append(
D_Block(
layer_idx=i,
N=self.Ns[i],
S=self.Ss[i],
activation=self.activation,
num_tfc=unet_args.num_tfc,
)
)
def call(self, inputs, contracting_layers):
x = inputs
@ -208,12 +244,13 @@ class Decoder(layers.Layer):
x = self.dblocks[i - 1](x, contracting_layers[i - 1])
return x
class Encoder(tf.keras.Model):
"""
[B, T, F, N] => skip connections , [B, T, F, N_4]
Encoder side of the U-Net subnetwork.
"""
def __init__(self, Ns, Ss, unet_args):
super(Encoder, self).__init__()
self.Ns = Ns
@ -225,14 +262,22 @@ class Encoder(tf.keras.Model):
self.eblocks = []
for i in range(self.depth):
self.eblocks.append(
E_Block(
layer_idx=i,
N0=self.Ns[i],
N=self.Ns[i + 1],
S=self.Ss[i],
activation=self.activation,
num_tfc=unet_args.num_tfc,
)
)
self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)
def call(self, inputs):
x = inputs
for i in range(self.depth):
x, x_contract = self.eblocks[i](x)
self.contracting_layers[i] = x_contract # if remove 0, correct this
@ -240,8 +285,8 @@ class Encoder(tf.keras.Model):
return x, self.contracting_layers
class MultiStage_denoise(tf.keras.Model):
def __init__(self, unet_args=None):
super(MultiStage_denoise, self).__init__()
@ -259,13 +304,14 @@ class MultiStage_denoise(tf.keras.Model):
# initial feature extractor
ksize = (7, 7)
self.paddings_1 = get_paddings(ksize)
self.conv2d_1 = layers.Conv2D(
filters=self.Ns[0],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=self.activation,
)
self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
@ -281,39 +327,41 @@ class MultiStage_denoise(tf.keras.Model):
# initial feature extractor
ksize = (7, 7)
self.paddings_2 = get_paddings(ksize)
self.conv2d_2 = layers.Conv2D(
filters=self.Ns[0],
kernel_size=ksize,
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
activation=self.activation,
)
self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)
@tf.function()
def call(self, inputs):
if self.use_fencoding:
x_w_freq = self.freq_encoding(inputs) # None, None, 1025, 12
else:
x_w_freq = inputs
# initial feature extractor
x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
x = self.conv2d_1(x) # None, None, 1025, 32
x, contracting_layers_s1 = self.encoder_s1(x)
# decoder
feats_s1 = self.decoder_s1(
x, contracting_layers_s1
) # None, None, 1025, 32 features
if self.num_stages > 1:
# SAM module
Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
# initial feature extractor
x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
x = self.conv2d_2(x)
if self.use_sam:
@ -323,7 +371,9 @@ class MultiStage_denoise(tf.keras.Model):
x, contracting_layers_s2 = self.encoder_s2(x)
feats_s2 = self.decoder_s2(
x, contracting_layers_s2
) # None, None, 1025, 32 features
# consider implementing a third stage?
@ -333,23 +383,27 @@ class MultiStage_denoise(tf.keras.Model):
pred_stage_1 = self.finalblock(feats_s1)
return pred_stage_1
class I_Block(layers.Layer):
"""
[B, T, F, N] => [B, T, F, N]
Intermediate block:
Basically, a densenet block with a residual connection
"""
def __init__(self, N, activation, num_tfc, **kwargs):
super(I_Block, self).__init__(**kwargs)
ksize = (3, 3)
self.tfc = DenseBlock(num_tfc, N, ksize, activation)
self.conv2d_res = layers.Conv2D(
filters=N,
kernel_size=(1, 1),
kernel_initializer=TruncatedNormal(),
strides=1,
padding="VALID",
)
def call(self, inputs):
x = self.tfc(inputs)
@ -359,7 +413,6 @@ class I_Block(layers.Layer):
class E_Block(layers.Layer):
def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
super(E_Block, self).__init__(**kwargs)
self.layer_idx = layer_idx
@ -371,31 +424,33 @@ class E_Block(layers.Layer):
ksize = (S[0] + 2, S[1] + 2)
self.paddings_2 = get_paddings(ksize)
self.conv2d_2 = layers.Conv2D(
filters=N,
kernel_size=(S[0] + 2, S[1] + 2),
kernel_initializer=TruncatedNormal(),
strides=S,
padding="VALID",
activation=self.activation,
)
def call(self, inputs, training=None, **kwargs):
x = self.i_block(inputs)
x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
x_down = self.conv2d_2(x_down)
return x_down, x
def get_config(self):
return dict(
layer_idx=self.layer_idx,
N=self.N,
S=self.S,
**super(E_Block, self).get_config(),
)
class D_Block(layers.Layer):
def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
super(D_Block, self).__init__(**kwargs)
self.layer_idx = layer_idx
@ -405,29 +460,35 @@ class D_Block(layers.Layer):
ksize = (S[0] + 2, S[1] + 2)
self.paddings_1 = get_paddings(ksize)
self.tconv_1 = layers.Conv2DTranspose(
filters=N,
kernel_size=(S[0] + 2, S[1] + 2),
kernel_initializer=TruncatedNormal(),
strides=S,
activation=self.activation,
padding="VALID",
)
self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")
self.projection = layers.Conv2D(
filters=N,
kernel_size=(1, 1),
kernel_initializer=TruncatedNormal(),
strides=1,
activation=self.activation,
padding="VALID",
)
self.cropadd = CropAddBlock()
self.cropconcat = CropConcatBlock()
self.i_block = I_Block(N, activation, num_tfc)
def call(
self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
):
x = inputs
x = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
x = self.tconv_1(inputs)
x2 = self.upsampling(inputs)
@ -437,39 +498,40 @@ class D_Block(layers.Layer):
x = self.cropadd(x, x2)
x = self.cropconcat(x, bridge)
x = self.i_block(x)
return x
def get_config(self):
return dict(
layer_idx=self.layer_idx,
N=self.N,
S=self.S,
**super(D_Block, self).get_config(),
)
class CropAddBlock(layers.Layer):
def call(self, down_layer, x, **kwargs):
x1_shape = tf.shape(down_layer)
x2_shape = tf.shape(x)
height_diff = (x1_shape[1] - x2_shape[1]) // 2
width_diff = (x1_shape[2] - x2_shape[2]) // 2
down_layer_cropped = down_layer[
:,
height_diff : (x2_shape[1] + height_diff),
width_diff : (x2_shape[2] + width_diff),
:,
]
x = layers.Add()([down_layer_cropped, x])
return x
class CropConcatBlock(layers.Layer):
def call(self, down_layer, x, **kwargs):
x1_shape = tf.shape(down_layer)
x2_shape = tf.shape(x)
@ -477,10 +539,12 @@ class CropConcatBlock(layers.Layer):
height_diff = (x1_shape[1] - x2_shape[1]) // 2
width_diff = (x1_shape[2] - x2_shape[2]) // 2
down_layer_cropped = down_layer[
:,
height_diff : (x2_shape[1] + height_diff),
width_diff : (x2_shape[2] + width_diff),
:,
]
x = tf.concat([down_layer_cropped, x], axis=-1)
return x
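Shape check for the two crop blocks above: the (larger) encoder map is center-cropped to the decoder's spatial size before the add or concat; sizes here are illustrative:

import tensorflow as tf

down = tf.zeros((1, 36, 130, 8))  # larger encoder feature map
up = tf.zeros((1, 32, 128, 8))    # decoder feature map
print(CropAddBlock()(down, up).shape)     # (1, 32, 128, 8)
print(CropConcatBlock()(down, up).shape)  # (1, 32, 128, 16)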