Compare commits
10 commits: ff2dff25d5...9870240572
| Author | SHA1 | Date |
|---|---|---|
| | 9870240572 | |
| | 114fce7c84 | |
| | 1fe1988ed5 | |
| | fb7a32a1ff | |
| | b7d071a54c | |
| | 6eb46ba2fc | |
| | 018f4418e6 | |
| | a1a92afefd | |
| | 214c872c51 | |
| | 210cd0edd8 | |
3 .gitignore vendored Normal file
@@ -0,0 +1,3 @@
+experiments
+outputs
+__pycache__
15 README.md
@@ -13,7 +13,7 @@ width="400px"></p>
 
 Listen to our [audio samples](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/)
 
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
 
 ## Requirements
 
 You will need at least python 3.7 and CUDA 10.1 if you want to use GPU. See `requirements.txt` for the required package versions.
@@ -24,7 +24,10 @@ To install the environment through anaconda, follow the instructions:
     conda activate historical_denoiser
 
 ## Denoising Recordings
-Run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
+You can denoise your recordings in the cloud using the Colab notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
+
+Otherwise, run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
 
     git clone https://github.com/eloimoliner/denoising-historical-recordings.git
     cd denoising-historical-recordings
@@ -37,7 +40,13 @@ If the environment is installed correctly, you can denoise an audio file by runn
 
 A ".wav" file with the denoised version, as well as the residual noise and the original signal in "mono", will be generated in the same directory as the input file.
 
 ## Training
-TODO
+To retrain the model, follow the instructions:
+
+Download the [Gramophone Noise Dataset](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/media/datasets/Gramophone_Record_Noise_Dataset.zip), or any other dataset containing recording noises.
+
+Prepare a dataset of clean music (e.g. [MusicNet](https://zenodo.org/record/5120004#.YnN-96IzbmE)).
 
 ## Remarks
 
 The trained model is specialized in denoising gramophone recordings, such as the ones included in this collection: https://archive.org/details/georgeblood. It has proven robust to a wide range of noises, but it may produce artifacts if you run inference on something completely different.
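Aside, for illustration (not part of the diff): per `inference.py` further down, the outputs are written next to the input as `<name>_noisy_input.wav`, `<name>_denoised.wav`, and `<name>_residual.wav`. A quick sketch for inspecting them, assuming a hypothetical input file `recording.wav` was processed:

```python
import soundfile as sf

# suffixes match the sf.write calls in inference.py below
for suffix in ("_noisy_input", "_denoised", "_residual"):
    audio, fs = sf.read("recording" + suffix + ".wav")
    print(suffix, fs, "Hz,", round(len(audio) / fs, 2), "s")
```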
colab/demo.ipynb

@@ -7,7 +7,7 @@
 "colab_type": "text"
 },
 "source": [
-"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
 ]
 },
 {
@@ -40,7 +40,8 @@
 "* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
 "* Press ▶️ on the left of each of the cells\n",
 "* View the code: Double-click any of the cells\n",
-"* Hide the code: Double click the right side of the cell\n"
+"* Hide the code: Double click the right side of the cell\n",
+"* For some reason, this notebook does not work in Firefox, so please use another browser.\n"
 ],
 "metadata": {
 "id": "8UON6ncSApA9"
@@ -207,7 +208,7 @@
-"id": "TQBDTmO4mUBx"
 },
+"id": "TQBDTmO4mUBx",
-"execution_count": 4,
+"execution_count": null,
 "outputs": []
 },
 {
@@ -243,7 +244,7 @@
 "outputId": "2d05860c-536d-45f8-92b4-d2ba6f5a54c5"
 },
 "id": "50Kmdy6AtbhW",
-"execution_count": 5,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "display_data",
@@ -296,7 +297,7 @@
 "outputId": "173f5355-2939-41fe-c702-591aa752fc7e"
 },
 "id": "0po6zpvrylc2",
-"execution_count": 6,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -333,7 +334,7 @@
 "outputId": "54588c26-0b3c-42bf-aca2-8316ab54603f"
 },
 "id": "3tEshWBezYvf",
-"execution_count": 7,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "display_data",
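Aside, for illustration (not part of the diff): the notebook changes above replace concrete `execution_count` values with `null`, the normalization you typically get when re-saving from Colab. The same cleanup can be scripted with `nbformat` (a sketch, assuming the package is installed):

```python
import nbformat

nb = nbformat.read("colab/demo.ipynb", as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.execution_count = None  # the normalization seen in this diff
nbformat.write(nb, "colab/demo.ipynb")
```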
dataset_loader.py

@@ -4,485 +4,561 @@ import tensorflow as tf
 import random
 import os
 import numpy as np
 from scipy.fft import fft, ifft
 import soundfile as sf
 import math
 import pandas as pd
 import scipy as sp
 import glob
 from tqdm import tqdm
 
-#generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis.
-def __noise_sample_generator(info_file,fs, length_seq, split):
-    head=os.path.split(info_file)[0]
-    load_data=pd.read_csv(info_file)
-    #split= train, validation, test
-    load_data_split=load_data.loc[load_data["split"]==split]
-    load_data_split=load_data_split.reset_index(drop=True)
+# generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seq, trim them and yield them with no further processing. Otherwise, if the segment is shorter, it will extend the length using concatenative synthesis.
+def __noise_sample_generator(info_file, fs, length_seq, split):
+    head = os.path.split(info_file)[0]
+    load_data = pd.read_csv(info_file)
+    # split = train, validation, test
+    load_data_split = load_data.loc[load_data["split"] == split]
+    load_data_split = load_data_split.reset_index(drop=True)
     while True:
         r = list(range(len(load_data_split)))
-        if split!="test":
+        if split != "test":
             random.shuffle(r)
         for i in r:
-            segments=ast.literal_eval(load_data_split.loc[i,"segments"])
-            if split=="test":
-                loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
+            segments = ast.literal_eval(load_data_split.loc[i, "segments"])
+            if split == "test":
+                loaded_data, Fs = sf.read(
+                    os.path.join(
+                        head,
+                        load_data_split["recording"].loc[i],
+                        load_data_split["largest_segment"].loc[i],
+                    )
+                )
             else:
-                num=np.random.randint(0,len(segments))
-                loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
-            assert(fs==Fs, "wrong sampling rate")
+                num = np.random.randint(0, len(segments))
+                loaded_data, Fs = sf.read(
+                    os.path.join(
+                        head, load_data_split["recording"].loc[i], segments[num]
+                    )
+                )
+            assert fs == Fs, "wrong sampling rate"
 
-            yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
+            yield __extend_sample_by_repeating(loaded_data, fs, length_seq)
-def __extend_sample_by_repeating(data, fs,seq_len):
-    rpm=78
-    target_samp=seq_len
-    large_data=np.zeros(shape=(target_samp,2))
-
-    if len(data)>=target_samp:
-        large_data=data[0:target_samp]
+def __extend_sample_by_repeating(data, fs, seq_len):
+    rpm = 78
+    target_samp = seq_len
+    large_data = np.zeros(shape=(target_samp, 2))
+
+    if len(data) >= target_samp:
+        large_data = data[0:target_samp]
         return large_data
 
-    bls=(1000*44100)/1000 #hardcoded
+    bls = (1000 * 44100) / 1000  # hardcoded
 
-    window=np.stack((np.hanning(bls) ,np.hanning(bls)), axis=1)
-    window_left=window[0:int(bls/2),:]
-    window_right=window[int(bls/2)::,:]
-    bls=int(bls/2)
+    window = np.stack((np.hanning(bls), np.hanning(bls)), axis=1)
+    window_left = window[0 : int(bls / 2), :]
+    window_right = window[int(bls / 2) : :, :]
+    bls = int(bls / 2)
 
-    rps=rpm/60
-    period=1/rps
+    rps = rpm / 60
+    period = 1 / rps
 
-    period_sam=int(period*fs)
+    period_sam = int(period * fs)
 
-    overhead=len(data)%period_sam
+    overhead = len(data) % period_sam
 
-    if(overhead>bls):
-        complete_periods=(len(data)//period_sam)*period_sam
+    if overhead > bls:
+        complete_periods = (len(data) // period_sam) * period_sam
     else:
-        complete_periods=(len(data)//period_sam -1)*period_sam
+        complete_periods = (len(data) // period_sam - 1) * period_sam
 
-    a=np.multiply(data[0:bls], window_left)
-    b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
-    c_1=np.concatenate((data[0:complete_periods,:],b))
-    c_2=np.concatenate((a,data[bls:complete_periods,:],b))
-    c_3=np.concatenate((a,data[bls::,:]))
-
-    large_data[0:complete_periods+bls,:]=c_1
+    a = np.multiply(data[0:bls], window_left)
+    b = np.multiply(data[complete_periods : complete_periods + bls], window_right)
+    c_1 = np.concatenate((data[0:complete_periods, :], b))
+    c_2 = np.concatenate((a, data[bls:complete_periods, :], b))
+    c_3 = np.concatenate((a, data[bls::, :]))
+
+    large_data[0 : complete_periods + bls, :] = c_1
 
-    pointer=complete_periods
-    not_finished=True
-    while (not_finished):
-        if target_samp>pointer+complete_periods+bls:
-            large_data[pointer:pointer+complete_periods+bls] +=c_2
-            pointer+=complete_periods
+    pointer = complete_periods
+    not_finished = True
+    while not_finished:
+        if target_samp > pointer + complete_periods + bls:
+            large_data[pointer : pointer + complete_periods + bls] += c_2
+            pointer += complete_periods
         else:
-            large_data[pointer::]+=c_3[0:(target_samp-pointer)]
-            #finish
-            not_finished=False
+            large_data[pointer::] += c_3[0 : (target_samp - pointer)]
+            # finish
+            not_finished = False
 
     return large_data
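Aside, for illustration (not part of the diff): `__extend_sample_by_repeating` loops a noise sample in whole 78 rpm revolutions, so the periodic click pattern of a record stays aligned, and joins copies with a Hann cross-fade. A simplified mono sketch of the same idea (`loop_with_crossfade` is a hypothetical helper; it assumes the sample is at least one revolution plus one fade long):

```python
import numpy as np

def loop_with_crossfade(x, target_len, fs, rpm=78, fade_ms=23):
    period = int(fs * 60 / rpm)      # samples per revolution (~33923 at 44.1 kHz)
    fade = int(fs * fade_ms / 1000)  # cross-fade length at each seam
    n_rev = max((len(x) - fade) // period, 1)
    seg_len = n_rev * period + fade  # whole revolutions plus fade look-ahead
    seg = x[:seg_len].copy()
    win = np.hanning(2 * fade)
    out = np.zeros(target_len + seg_len)
    pos = 0
    while pos < target_len:
        s = seg.copy()
        if pos > 0:
            s[:fade] *= win[:fade]   # fade in over the previous copy's tail
        s[-fade:] *= win[fade:]      # fade out into the next copy
        out[pos : pos + seg_len] += s
        pos += seg_len - fade        # overlap exactly the fade region
    return out[:target_len]

fs = 44100
extended = loop_with_crossfade(np.random.randn(4 * fs), 10 * fs, fs)
assert extended.shape == (10 * fs,)
```

The rising and falling Hann halves sum to approximately one across each seam, so the looping is unity-gain.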
-def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
-
-    records_info=os.path.join(path_recordings,"audio_files.txt")
+def generate_real_recordings_data(
+    path_recordings, fs=44100, seg_len_s=15, stereo=False
+):
+    records_info = os.path.join(path_recordings, "audio_files.txt")
     num_lines = sum(1 for line in open(records_info))
-    f = open(records_info,"r")
-    #load data record files
+    f = open(records_info, "r")
+    # load data record files
     print("Loading record files")
-    records=[]
-    seg_len=fs*seg_len_s
-    pointer=int(fs*5) #starting at second 5 by default
+    records = []
+    seg_len = fs * seg_len_s
+    pointer = int(fs * 5)  # starting at second 5 by default
     for i in tqdm(range(num_lines)):
-        audio=f.readline()
-        audio=audio[:-1]
-        data, fs=sf.read(os.path.join(path_recordings,audio))
-        if len(data.shape)>1 and not(stereo):
-            data=np.mean(data,axis=1)
-        #elif stereo and len(data.shape)==1:
+        audio = f.readline()
+        audio = audio[:-1]
+        data, fs = sf.read(os.path.join(path_recordings, audio))
+        if len(data.shape) > 1 and not (stereo):
+            data = np.mean(data, axis=1)
+        # elif stereo and len(data.shape)==1:
         #    data=np.stack((data, data), axis=1)
 
-        #normalize
-        data=data/np.max(np.abs(data))
-        segment=data[pointer:pointer+seg_len]
+        # normalize
+        data = data / np.max(np.abs(data))
+        segment = data[pointer : pointer + seg_len]
         records.append(segment.astype("float32"))
 
     return records
-def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False):
-
+def generate_paired_data_test_formal(
+    path_pianos,
+    path_noises,
+    noise_amount="low_snr",
+    num_samples=-1,
+    fs=44100,
+    seg_len_s=5,
+    extend=True,
+    stereo=False,
+    prenoise=False,
+):
     print(num_samples)
-    segments_clean=[]
-    segments_noisy=[]
-    seg_len=fs*seg_len_s
-    noises_info=os.path.join(path_noises,"info.csv")
+    segments_clean = []
+    segments_noisy = []
+    seg_len = fs * seg_len_s
+    noises_info = os.path.join(path_noises, "info.csv")
     np.random.seed(42)
-    if noise_amount=="low_snr":
-        SNRs=np.random.uniform(2,6,num_samples)
-    elif noise_amount=="mid_snr":
-        SNRs=np.random.uniform(6,12,num_samples)
+    if noise_amount == "low_snr":
+        SNRs = np.random.uniform(2, 6, num_samples)
+    elif noise_amount == "mid_snr":
+        SNRs = np.random.uniform(6, 12, num_samples)
 
-    scales=np.random.uniform(-4,0,num_samples)
-    #SNRs=[2,6,12] #HARDCODED!!!!
-    i=0
+    scales = np.random.uniform(-4, 0, num_samples)
+    # SNRs=[2,6,12] #HARDCODED!!!!
+    i = 0
     print(path_pianos[0])
     print(seg_len)
-    train_samples=glob.glob(os.path.join(path_pianos[0],"*.wav"))
-    train_samples=sorted(train_samples)
+    train_samples = glob.glob(os.path.join(path_pianos[0], "*.wav"))
+    train_samples = sorted(train_samples)
 
     if prenoise:
-        noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, extend, "test") #Adds 1s of silence add the begiing, longer noise
+        noise_generator = __noise_sample_generator(
+            noises_info, fs, seg_len + fs, extend, "test"
+        )  # adds 1 s of silence at the beginning, longer noise
     else:
-        noise_generator=__noise_sample_generator(noises_info,fs, seg_len, extend, "test") #this will take care of everything
-    #load data clean files
-    for file in tqdm(train_samples): #add [1:5] for testing
+        noise_generator = __noise_sample_generator(
+            noises_info, fs, seg_len, extend, "test"
+        )  # this will take care of everything
+    # load data clean files
+    for file in tqdm(train_samples):  # add [1:5] for testing
         data_clean, samplerate = sf.read(file)
-        if samplerate!=fs:
+        if samplerate != fs:
             print("!!!!WRONG SAMPLE RATe!!!")
-        #Stereo to mono
-        if len(data_clean.shape)>1 and not(stereo):
-            data_clean=np.mean(data_clean,axis=1)
-        #elif stereo and len(data_clean.shape)==1:
+        # Stereo to mono
+        if len(data_clean.shape) > 1 and not (stereo):
+            data_clean = np.mean(data_clean, axis=1)
+        # elif stereo and len(data_clean.shape)==1:
         #    data_clean=np.stack((data_clean, data_clean), axis=1)
-        #normalize
-        data_clean=data_clean/np.max(np.abs(data_clean))
-        #data_clean_loaded.append(data_clean)
+        # normalize
+        data_clean = data_clean / np.max(np.abs(data_clean))
+        # data_clean_loaded.append(data_clean)
 
-        #framify data clean files
+        # framify data clean files
 
-        #framify arguments: seg_len, hop_size
-        hop_size=int(seg_len)# no overlap
+        # framify arguments: seg_len, hop_size
+        hop_size = int(seg_len)  # no overlap
 
-        num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
+        num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
         print(num_frames)
-        if num_frames==0:
-            data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
-            num_frames=1
+        if num_frames == 0:
+            data_clean = np.concatenate(
+                (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
+                axis=0,
+            )
+            num_frames = 1
 
-        data_not_finished=True
-        pointer=0
-        while(data_not_finished):
-            if i>=num_samples:
+        data_not_finished = True
+        pointer = 0
+        while data_not_finished:
+            if i >= num_samples:
                 break
-            segment=data_clean[pointer:pointer+seg_len]
-            pointer=pointer+hop_size
-            if pointer+seg_len>len(data_clean):
-                data_not_finished=False
-            segment=segment.astype('float32')
+            segment = data_clean[pointer : pointer + seg_len]
+            pointer = pointer + hop_size
+            if pointer + seg_len > len(data_clean):
+                data_not_finished = False
+            segment = segment.astype("float32")
 
-            #SNRs=np.random.uniform(2,20)
-            snr=SNRs[i]
-            scale=scales[i]
-            #load noise signal
-            data_noise= next(noise_generator)
-            data_noise=np.mean(data_noise,axis=1)
-            #normalize
-            data_noise=data_noise/np.max(np.abs(data_noise))
-            new_noise=data_noise #if more processing needed, add here
-            #load clean data
-            #configure sizes
-            power_clean=np.var(segment)
-            #estimate noise power
+            # SNRs=np.random.uniform(2,20)
+            snr = SNRs[i]
+            scale = scales[i]
+            # load noise signal
+            data_noise = next(noise_generator)
+            data_noise = np.mean(data_noise, axis=1)
+            # normalize
+            data_noise = data_noise / np.max(np.abs(data_noise))
+            new_noise = data_noise  # if more processing needed, add here
+            # load clean data
+            # configure sizes
+            power_clean = np.var(segment)
+            # estimate noise power
            if prenoise:
-                power_noise=np.var(new_noise[fs::])
+                power_noise = np.var(new_noise[fs::])
             else:
-                power_noise=np.var(new_noise)
+                power_noise = np.var(new_noise)
 
-            snr = 10.0**(snr/10.0)
+            snr = 10.0 ** (snr / 10.0)
 
-            #sum both signals according to snr
+            # sum both signals according to snr
             if prenoise:
-                segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
-            summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
+                segment = np.concatenate(
+                    (np.zeros(shape=(fs,)), segment), axis=0
+                )  # add one second of silence
+            summed = (
+                segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
+            )  # not sure if this is correct, maybe revisit later!!
 
-            summed=summed.astype('float32')
-            #yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
+            summed = summed.astype("float32")
+            # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
 
-            summed=10.0**(scale/10.0) *summed
-            segment=10.0**(scale/10.0) *segment
-            segments_noisy.append(summed.astype('float32'))
-            segments_clean.append(segment.astype('float32'))
-            i=i+1
+            summed = 10.0 ** (scale / 10.0) * summed
+            segment = 10.0 ** (scale / 10.0) * segment
+            segments_noisy.append(summed.astype("float32"))
+            segments_clean.append(segment.astype("float32"))
+            i = i + 1
 
     return segments_noisy, segments_clean
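Aside (not part of the diff): the scaling questioned in the `# not sure if this is correct` comment does follow from the usual definition. With clean power P_s and noise power P_n, mixing y = s + a*n with a = sqrt(P_s / (snr * P_n)) (snr in linear units) gives 10*log10(P_s / (a^2 * P_n)) = SNR_dB exactly. A quick numpy check:

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(44100)  # stand-in clean segment
n = rng.standard_normal(44100)  # stand-in noise segment
snr_db = 6.0
snr = 10.0 ** (snr_db / 10.0)   # linear SNR, as in the code above
a = np.sqrt(np.var(s) / (snr * np.var(n)))
measured_db = 10 * np.log10(np.var(s) / np.var(a * n))
print(round(measured_db, 2))    # 6.0 -- the mixing rule is correct
```

One genuine wrinkle: the separate `scale` factor is converted with `10.0 ** (scale / 10.0)` but applied to amplitude; a strict amplitude-dB reading would use `10 ** (scale / 20)`. Since it scales the clean and noisy signals identically, it does not affect the SNR.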
-def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
-
-    segments_clean=[]
-    segments_noisy=[]
-    seg_len=fs*seg_len_s
-    noises_info=os.path.join(path_noises,"info.csv")
-    SNRs=[2,6,12] #HARDCODED!!!!
+def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5):
+    segments_clean = []
+    segments_noisy = []
+    seg_len = fs * seg_len_s
+    noises_info = os.path.join(path_noises, "info.csv")
+    SNRs = [2, 6, 12]  # HARDCODED!!!!
     for path in path_music:
         print(path)
-        train_samples=glob.glob(os.path.join(path,"*.wav"))
-        train_samples=sorted(train_samples)
+        train_samples = glob.glob(os.path.join(path, "*.wav"))
+        train_samples = sorted(train_samples)
 
-        noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
-        #load data clean files
-        jj=0
-        for file in tqdm(train_samples): #add [1:5] for testing
+        noise_generator = __noise_sample_generator(
+            noises_info, fs, seg_len, "test"
+        )  # this will take care of everything
+        # load data clean files
+        for file in tqdm(train_samples):  # add [1:5] for testing
             data_clean, samplerate = sf.read(file)
-            if samplerate!=fs:
+            if samplerate != fs:
                 print("!!!!WRONG SAMPLE RATe!!!")
-            #Stereo to mono
-            if len(data_clean.shape)>1:
-                data_clean=np.mean(data_clean,axis=1)
-            #normalize
-            data_clean=data_clean/np.max(np.abs(data_clean))
-            #data_clean_loaded.append(data_clean)
+            # Stereo to mono
+            if len(data_clean.shape) > 1:
+                data_clean = np.mean(data_clean, axis=1)
+            # normalize
+            data_clean = data_clean / np.max(np.abs(data_clean))
+            # data_clean_loaded.append(data_clean)
 
-            #framify data clean files
+            # framify data clean files
 
-            #framify arguments: seg_len, hop_size
-            hop_size=int(seg_len)# no overlap
+            # framify arguments: seg_len, hop_size
+            hop_size = int(seg_len)  # no overlap
 
-            num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
-            if num_frames==0:
-                data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
-                num_frames=1
+            num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
+            if num_frames == 0:
+                data_clean = np.concatenate(
+                    (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
+                    axis=0,
+                )
+                num_frames = 1
 
-            pointer=0
-            segment=data_clean[pointer:pointer+(seg_len-2*fs)]
-            segment=segment.astype('float32')
-            segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #I hope its ok
-            #segments_clean.append(segment)
+            pointer = 0
+            segment = data_clean[pointer : pointer + (seg_len - 2 * fs)]
+            segment = segment.astype("float32")
+            segment = np.concatenate(
+                (np.zeros(shape=(2 * fs,)), segment), axis=0
+            )  # prepend 2 s of silence
+            # segments_clean.append(segment)
 
             for snr in SNRs:
-                #load noise signal
-                data_noise= next(noise_generator)
-                data_noise=np.mean(data_noise,axis=1)
-                #normalize
-                data_noise=data_noise/np.max(np.abs(data_noise))
-                new_noise=data_noise #if more processing needed, add here
-                #load clean data
-                #configure sizes
-                #estimate clean signal power
-                power_clean=np.var(segment)
-                #estimate noise power
-                power_noise=np.var(new_noise)
+                # load noise signal
+                data_noise = next(noise_generator)
+                data_noise = np.mean(data_noise, axis=1)
+                # normalize
+                data_noise = data_noise / np.max(np.abs(data_noise))
+                new_noise = data_noise  # if more processing needed, add here
+                # load clean data
+                # configure sizes
+                # estimate clean signal power
+                power_clean = np.var(segment)
+                # estimate noise power
+                power_noise = np.var(new_noise)
 
-                snr = 10.0**(snr/10.0)
+                snr = 10.0 ** (snr / 10.0)
 
-                #sum both signals according to snr
-                summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
-                summed=summed.astype('float32')
-                #yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
+                # sum both signals according to snr
+                summed = (
+                    segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
+                )  # not sure if this is correct, maybe revisit later!!
+                summed = summed.astype("float32")
+                # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
 
-                segments_noisy.append(summed.astype('float32'))
-                segments_clean.append(segment.astype('float32'))
+                segments_noisy.append(summed.astype("float32"))
+                segments_clean.append(segment.astype("float32"))
 
     return segments_noisy, segments_clean
-def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5):
-
-    val_samples=[]
+def generate_val_data(
+    path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5
+):
+    val_samples = []
     for path in path_music:
-        val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
+        val_samples.extend(glob.glob(os.path.join(path, "*.wav")))
 
-    #load data clean files
+    # load data clean files
     print("Loading clean files")
-    data_clean_loaded=[]
-    for ff in tqdm(range(0,len(val_samples))): #add [1:5] for testing
+    data_clean_loaded = []
+    for ff in tqdm(range(0, len(val_samples))):  # add [1:5] for testing
         data_clean, samplerate = sf.read(val_samples[ff])
-        if samplerate!=fs:
+        if samplerate != fs:
             print("!!!!WRONG SAMPLE RATe!!!")
-        #Stereo to mono
-        if len(data_clean.shape)>1 :
-            data_clean=np.mean(data_clean,axis=1)
-        #normalize
-        data_clean=data_clean/np.max(np.abs(data_clean))
+        # Stereo to mono
+        if len(data_clean.shape) > 1:
+            data_clean = np.mean(data_clean, axis=1)
+        # normalize
+        data_clean = data_clean / np.max(np.abs(data_clean))
         data_clean_loaded.append(data_clean)
     del data_clean
 
-    #framify data clean files
+    # framify data clean files
     print("Framifying clean files")
-    seg_len=fs*seg_len_s
-    segments_clean=[]
+    seg_len = fs * seg_len_s
+    segments_clean = []
     for file in tqdm(data_clean_loaded):
-        #framify arguments: seg_len, hop_size
-        hop_size=int(seg_len)# no overlap
+        # framify arguments: seg_len, hop_size
+        hop_size = int(seg_len)  # no overlap
 
-        num_frames=np.floor(len(file)/hop_size - seg_len/hop_size +1)
-        pointer=0
-        for i in range(0,int(num_frames)):
-            segment=file[pointer:pointer+seg_len]
-            pointer=pointer+hop_size
-            segment=segment.astype('float32')
+        num_frames = np.floor(len(file) / hop_size - seg_len / hop_size + 1)
+        pointer = 0
+        for i in range(0, int(num_frames)):
+            segment = file[pointer : pointer + seg_len]
+            pointer = pointer + hop_size
+            segment = segment.astype("float32")
             segments_clean.append(segment)
 
     del data_clean_loaded
 
-    SNRs=np.random.uniform(2,20,len(segments_clean))
-    scales=np.random.uniform(-6,4,len(segments_clean))
-    #noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
-    noises_info=os.path.join(path_noises,"info.csv")
+    SNRs = np.random.uniform(2, 20, len(segments_clean))
+    scales = np.random.uniform(-6, 4, len(segments_clean))
+    # noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
+    noises_info = os.path.join(path_noises, "info.csv")
 
-    noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything
+    noise_generator = __noise_sample_generator(
+        noises_info, fs, seg_len, split
+    )  # this will take care of everything
 
-    #generate noisy segments
-    #load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
+    # generate noisy segments
+    # load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
 
-    #noise_samples=glob.glob(os.path.join(path_noises,"*.wav"))
-    segments_noisy=[]
+    # noise_samples=glob.glob(os.path.join(path_noises,"*.wav"))
+    segments_noisy = []
     print("Processing noisy segments")
 
-    for i in tqdm(range(0,len(segments_clean))):
-        #load noise signal
-        data_noise= next(noise_generator)
-        #Stereo to mono
-        data_noise=np.mean(data_noise,axis=1)
-        #normalize
-        data_noise=data_noise/np.max(np.abs(data_noise))
-        new_noise=data_noise #if more processing needed, add here
-        #load clean data
-        data_clean=segments_clean[i]
-        #configure sizes
+    for i in tqdm(range(0, len(segments_clean))):
+        # load noise signal
+        data_noise = next(noise_generator)
+        # Stereo to mono
+        data_noise = np.mean(data_noise, axis=1)
+        # normalize
+        data_noise = data_noise / np.max(np.abs(data_noise))
+        new_noise = data_noise  # if more processing needed, add here
+        # load clean data
+        data_clean = segments_clean[i]
+        # configure sizes
 
-        #estimate clean signal power
-        power_clean=np.var(data_clean)
-        #estimate noise power
-        power_noise=np.var(new_noise)
+        # estimate clean signal power
+        power_clean = np.var(data_clean)
+        # estimate noise power
+        power_noise = np.var(new_noise)
 
-        snr = 10.0**(SNRs[i]/10.0)
+        snr = 10.0 ** (SNRs[i] / 10.0)
 
-        #sum both signals according to snr
-        summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
-        #the rest is normal
+        # sum both signals according to snr
+        summed = (
+            data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise
+        )  # not sure if this is correct, maybe revisit later!!
+        # the rest is normal
 
-        summed=10.0**(scales[i]/10.0) *summed
-        segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i]
+        summed = 10.0 ** (scales[i] / 10.0) * summed
+        segments_clean[i] = 10.0 ** (scales[i] / 10.0) * segments_clean[i]
 
-        segments_noisy.append(summed.astype('float32'))
+        segments_noisy.append(summed.astype("float32"))
 
    return segments_noisy, segments_clean
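Aside (not part of the diff): `__noise_sample_generator` indexes `info.csv` by the columns `recording`, `largest_segment`, `segments` (a Python-literal list parsed with `ast.literal_eval`) and `split`. The exact schema is inferred from those lookups, so this is an assumption; a minimal compatible file could be built like so:

```python
import pandas as pd

# hypothetical rows; column names inferred from the lookups above
info = pd.DataFrame(
    {
        "recording": ["rec_001", "rec_002"],
        "largest_segment": ["seg_03.wav", "seg_01.wav"],
        "segments": ["['seg_01.wav', 'seg_03.wav']", "['seg_01.wav']"],
        "split": ["train", "validation"],
    }
)
info.to_csv("info.csv", index=False)
```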
-def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False):
-
-    train_samples=[]
+def generator_train(
+    path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False
+):
+    train_samples = []
     for path in path_music:
-        train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav")))
+        train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"), "*.wav")))
 
-    seg_len=fs*seg_len_s
-    noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
-    noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything
-    #load data clean files
+    seg_len = fs * seg_len_s
+    noises_info = os.path.join(path_noises.decode("utf-8"), "info.csv")
+    noise_generator = __noise_sample_generator(
+        noises_info, fs, seg_len, split.decode("utf-8")
+    )  # this will take care of everything
+    # load data clean files
     while True:
         random.shuffle(train_samples)
         for file in train_samples:
             data, samplerate = sf.read(file)
-            assert(samplerate==fs, "wrong sampling rate")
-            data_clean=data
-            #Stereo to mono
-            if len(data.shape)>1 :
-                data_clean=np.mean(data_clean,axis=1)
+            assert samplerate == fs, "wrong sampling rate"
+            data_clean = data
+            # Stereo to mono
+            if len(data.shape) > 1:
+                data_clean = np.mean(data_clean, axis=1)
 
-            #normalize
-            data_clean=data_clean/np.max(np.abs(data_clean))
+            # normalize
+            data_clean = data_clean / np.max(np.abs(data_clean))
 
-            #framify data clean files
+            # framify data clean files
 
-            #framify arguments: seg_len, hop_size
-            hop_size=int(seg_len)
+            # framify arguments: seg_len, hop_size
+            hop_size = int(seg_len)
 
-            num_frames=np.floor(len(data_clean)/seg_len)
-            if num_frames==0:
-                data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
-                num_frames=1
-                pointer=0
-                data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation
-            elif num_frames>1:
-                pointer=np.random.randint(0,hop_size) #initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
+            num_frames = np.floor(len(data_clean) / seg_len)
+            if num_frames == 0:
+                data_clean = np.concatenate(
+                    (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
+                    axis=0,
+                )
+                num_frames = 1
+                pointer = 0
+                data_clean = np.roll(
+                    data_clean, np.random.randint(0, seg_len)
+                )  # if only one frame, roll it for augmentation
+            elif num_frames > 1:
+                pointer = np.random.randint(
+                    0, hop_size
+                )  # initial shifting, great for augmentation, better than overlap as we get different frames at each "while" iteration
             else:
-                pointer=0
+                pointer = 0
 
-            data_not_finished=True
-            while(data_not_finished):
-                segment=data_clean[pointer:pointer+seg_len]
-                pointer=pointer+hop_size
-                if pointer+seg_len>len(data_clean):
-                    data_not_finished=False
-                segment=segment.astype('float32')
+            data_not_finished = True
+            while data_not_finished:
+                segment = data_clean[pointer : pointer + seg_len]
+                pointer = pointer + hop_size
+                if pointer + seg_len > len(data_clean):
+                    data_not_finished = False
+                segment = segment.astype("float32")
 
-                SNRs=np.random.uniform(2,20)
-                scale=np.random.uniform(-6,4)
+                SNRs = np.random.uniform(2, 20)
+                scale = np.random.uniform(-6, 4)
 
-                #load noise signal
-                data_noise= next(noise_generator)
-                data_noise=np.mean(data_noise,axis=1)
-                #normalize
-                data_noise=data_noise/np.max(np.abs(data_noise))
-                new_noise=data_noise #if more processing needed, add here
-                #load clean data
-                #configure sizes
+                # load noise signal
+                data_noise = next(noise_generator)
+                data_noise = np.mean(data_noise, axis=1)
+                # normalize
+                data_noise = data_noise / np.max(np.abs(data_noise))
+                new_noise = data_noise  # if more processing needed, add here
+                # load clean data
+                # configure sizes
                 if stereo:
-                    #estimate clean signal power
-                    power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
-                    #estimate noise power
-                    power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
+                    # estimate clean signal power
+                    power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var(
+                        segment[:, 1]
+                    )
+                    # estimate noise power
+                    power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var(
+                        new_noise[:, 1]
+                    )
                 else:
-                    #estimate clean signal power
-                    power_clean=np.var(segment)
-                    #estimate noise power
-                    power_noise=np.var(new_noise)
+                    # estimate clean signal power
+                    power_clean = np.var(segment)
+                    # estimate noise power
+                    power_noise = np.var(new_noise)
 
-                snr = 10.0**(SNRs/10.0)
+                snr = 10.0 ** (SNRs / 10.0)
 
-                #sum both signals according to snr
-                summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
-                summed=10.0**(scale/10.0) *summed
-                segment=10.0**(scale/10.0) *segment
+                # sum both signals according to snr
+                summed = (
+                    segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
+                )  # not sure if this is correct, maybe revisit later!!
+                summed = 10.0 ** (scale / 10.0) * summed
+                segment = 10.0 ** (scale / 10.0) * segment
 
-                summed=summed.astype('float32')
+                summed = summed.astype("float32")
                 yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
-def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) :
-
+def load_data(
+    buffer_size,
+    path_music_train,
+    path_music_val,
+    path_noises,
+    fs=44100,
+    seg_len_s=5,
+    extend=True,
+    stereo=False,
+):
     print("Generating train dataset")
-    trainshape=int(fs*seg_len_s)
-
-    dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) )
+    trainshape = int(fs * seg_len_s)
+
+    dataset_train = tf.data.Dataset.from_generator(
+        generator_train,
+        args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo),
+        output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))),
+        output_types=(tf.float32, tf.float32),
+    )
 
     print("Generating validation dataset")
-    segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s)
+    segments_noisy, segments_clean = generate_val_data(
+        path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s
+    )
 
-    dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+    dataset_val = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
 
     return dataset_train.shuffle(buffer_size), dataset_val
 
 
 def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
     print("Generating test dataset")
-    segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs)
-    dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
-    #dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
-    #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
+    segments_noisy, segments_clean = generate_test_data(
+        path_pianos_test, path_noises, extend=True, **kwargs
+    )
+    dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+    # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
+    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
     return dataset_test
 
-def load_data_formal( path_pianos_test, path_noises, **kwargs) :
-
+
+def load_data_formal(path_pianos_test, path_noises, **kwargs):
     print("Generating test dataset")
-    segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
+    segments_noisy, segments_clean = generate_paired_data_test_formal(
+        path_pianos_test, path_noises, extend=True, **kwargs
+    )
     print("segments::")
     print(len(segments_noisy))
-    dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
-    #dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
-    #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
+    dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
+    # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
+    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
     return dataset_test
 
 
 def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
     print("Generating real test dataset")
 
-    segments_noisy=generate_real_recordings_data(path_recordings, **kwargs)
+    segments_noisy = generate_real_recordings_data(path_recordings, **kwargs)
 
-    dataset_test=tf.data.Dataset.from_tensor_slices(segments_noisy)
-    #train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
+    dataset_test = tf.data.Dataset.from_tensor_slices(segments_noisy)
+    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
     return dataset_test
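Aside (not part of the diff): `tf.data.Dataset.from_generator` converts everything in `args` to tensors, so strings reach the generator as byte strings, which is why `generator_train` calls `.decode("utf-8")` on its path arguments. A self-contained sketch of that behavior:

```python
import tensorflow as tf

def gen(path):
    # strings passed via `args` arrive as bytes, not str
    yield len(path.decode("utf-8"))

ds = tf.data.Dataset.from_generator(
    gen, args=("some/dir",), output_types=tf.int32, output_shapes=()
)
print(next(iter(ds)).numpy())  # 8
```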
242 inference.py
@@ -4,6 +4,7 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+
 def run(args):
     import unet
     import tensorflow as tf
@@ -12,127 +13,203 @@ def run(args):
     from tqdm import tqdm
     import scipy.signal
 
-    path_experiment=str(args.path_experiment)
+    path_experiment = str(args.path_experiment)
 
     unet_model = unet.build_model_denoise(unet_args=args.unet)
 
-    ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint')
+    ckpt = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
+    )
     unet_model.load_weights(ckpt)
 
     def do_stft(noisy):
-
         window_fn = tf.signal.hamming_window
 
-        win_size=args.stft.win_size
-        hop_size=args.stft.hop_size
+        win_size = args.stft.win_size
+        hop_size = args.stft.hop_size
 
-        stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
-        stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
+        stft_signal_noisy = tf.signal.stft(
+            noisy,
+            frame_length=win_size,
+            window_fn=window_fn,
+            frame_step=hop_size,
+            pad_end=True,
+        )
+        stft_noisy_stacked = tf.stack(
+            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
+            axis=-1,
+        )
 
         return stft_noisy_stacked
 
     def do_istft(data):
-
         window_fn = tf.signal.hamming_window
 
-        win_size=args.stft.win_size
-        hop_size=args.stft.hop_size
+        win_size = args.stft.win_size
+        hop_size = args.stft.hop_size
 
-        inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
+        inv_window_fn = tf.signal.inverse_stft_window_fn(
+            hop_size, forward_window_fn=window_fn
+        )
 
-        pred_cpx=data[...,0] + 1j * data[...,1]
-        pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
+        pred_cpx = data[..., 0] + 1j * data[..., 1]
+        pred_time = tf.signal.inverse_stft(
+            pred_cpx, win_size, hop_size, window_fn=inv_window_fn
+        )
         return pred_time
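Aside (not part of the diff): `inverse_stft_window_fn` derives the synthesis window that makes the Hamming-window analysis/synthesis pair an identity away from the edges. A minimal round-trip check, assuming TensorFlow 2.x and illustrative window sizes (the real values come from `args.stft`):

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(44100).astype(np.float32))
win, hop = 2048, 512  # illustrative; inference uses args.stft.win_size/hop_size
S = tf.signal.stft(x, frame_length=win, frame_step=hop,
                   window_fn=tf.signal.hamming_window, pad_end=True)
inv_win = tf.signal.inverse_stft_window_fn(
    hop, forward_window_fn=tf.signal.hamming_window)
y = tf.signal.inverse_stft(S, win, hop, window_fn=inv_win)
# interior samples reconstruct almost exactly; the edges differ due to padding
err = np.max(np.abs(x.numpy()[win:-win] - y.numpy()[win : 44100 - win]))
print(err)  # small (float32 round-off; edge effects are kept out of the slice)
```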
-    audio=str(args.inference.audio)
+    audio = str(args.inference.audio)
     data, samplerate = sf.read(audio)
     print(data.dtype)
-    #Stereo to mono
-    if len(data.shape)>1:
-        data=np.mean(data,axis=1)
+    # Stereo to mono
+    if len(data.shape) > 1:
+        data = np.mean(data, axis=1)
 
-    if samplerate!=44100:
+    if samplerate != 44100:
         print("Resampling")
 
-        data=scipy.signal.resample(data, int((44100 / samplerate )*len(data))+1)
+        data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1)
 
-    segment_size=44100*5 #20s segments
+    segment_size = 44100 * 5  # 5 s segments
 
-    length_data=len(data)
-    overlapsize=2048 #samples (46 ms)
-    window=np.hanning(2*overlapsize)
-    window_right=window[overlapsize::]
-    window_left=window[0:overlapsize]
-    audio_finished=False
-    pointer=0
-    denoised_data=np.zeros(shape=(len(data),))
-    residual_noise=np.zeros(shape=(len(data),))
-    numchunks=int(np.ceil(length_data/segment_size))
+    length_data = len(data)
+    overlapsize = 2048  # samples (46 ms)
+    window = np.hanning(2 * overlapsize)
+    window_right = window[overlapsize::]
+    window_left = window[0:overlapsize]
+    pointer = 0
+    denoised_data = np.zeros(shape=(len(data),))
+    residual_noise = np.zeros(shape=(len(data),))
+    numchunks = int(np.ceil(length_data / segment_size))
     for i in tqdm(range(numchunks)):
-        if pointer+segment_size<length_data:
-            segment=data[pointer:pointer+segment_size]
-            #do stft
-            segment_TF=do_stft(segment)
-            segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
+        if pointer + segment_size < length_data:
+            segment = data[pointer : pointer + segment_size]
+            # do stft
+            segment_TF = do_stft(segment)
+            segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
             pred = unet_model.predict(segment_TF_ds.batch(1))
-            pred=pred[0]
-            residual=segment_TF-pred[0]
-            residual=np.array(residual)
-            pred_time=do_istft(pred[0])
-            residual_time=do_istft(residual)
-            residual_time=np.array(residual_time)
+            pred = pred[0]
+            residual = segment_TF - pred[0]
+            residual = np.array(residual)
+            pred_time = do_istft(pred[0])
+            residual_time = do_istft(residual)
+            residual_time = np.array(residual_time)
 
-            if pointer==0:
-                pred_time=np.concatenate((pred_time[0:int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
-                residual_time=np.concatenate((residual_time[0:int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
+            if pointer == 0:
+                pred_time = np.concatenate(
+                    (
+                        pred_time[0 : int(segment_size - overlapsize)],
+                        np.multiply(
+                            pred_time[int(segment_size - overlapsize) : segment_size],
+                            window_right,
+                        ),
+                    ),
+                    axis=0,
+                )
+                residual_time = np.concatenate(
+                    (
+                        residual_time[0 : int(segment_size - overlapsize)],
+                        np.multiply(
+                            residual_time[
+                                int(segment_size - overlapsize) : segment_size
+                            ],
+                            window_right,
+                        ),
+                    ),
+                    axis=0,
+                )
             else:
-                pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
-                residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
+                pred_time = np.concatenate(
+                    (
+                        np.multiply(pred_time[0 : int(overlapsize)], window_left),
+                        pred_time[int(overlapsize) : int(segment_size - overlapsize)],
+                        np.multiply(
+                            pred_time[
+                                int(segment_size - overlapsize) : int(segment_size)
+                            ],
+                            window_right,
+                        ),
+                    ),
+                    axis=0,
+                )
+                residual_time = np.concatenate(
+                    (
+                        np.multiply(residual_time[0 : int(overlapsize)], window_left),
+                        residual_time[
+                            int(overlapsize) : int(segment_size - overlapsize)
+                        ],
+                        np.multiply(
+                            residual_time[
+                                int(segment_size - overlapsize) : int(segment_size)
+                            ],
+                            window_right,
+                        ),
+                    ),
+                    axis=0,
+                )
 
-            denoised_data[pointer:pointer+segment_size]=denoised_data[pointer:pointer+segment_size]+pred_time
-            residual_noise[pointer:pointer+segment_size]=residual_noise[pointer:pointer+segment_size]+residual_time
+            denoised_data[pointer : pointer + segment_size] = (
+                denoised_data[pointer : pointer + segment_size] + pred_time
+            )
+            residual_noise[pointer : pointer + segment_size] = (
+                residual_noise[pointer : pointer + segment_size] + residual_time
+            )
 
-            pointer=pointer+segment_size-overlapsize
+            pointer = pointer + segment_size - overlapsize
         else:
-            segment=data[pointer::]
-            lensegment=len(segment)
-            segment=np.concatenate((segment, np.zeros(shape=(int(segment_size-len(segment)),))), axis=0)
-            audio_finished=True
-            #do stft
-            segment_TF=do_stft(segment)
+            segment = data[pointer::]
+            lensegment = len(segment)
+            segment = np.concatenate(
+                (segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0
+            )
+            # do stft
+            segment_TF = do_stft(segment)
 
-            segment_TF_ds=tf.data.Dataset.from_tensors(segment_TF)
+            segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
 
             pred = unet_model.predict(segment_TF_ds.batch(1))
-            pred=pred[0]
-            residual=segment_TF-pred[0]
-            residual=np.array(residual)
-            pred_time=do_istft(pred[0])
-            pred_time=np.array(pred_time)
-            pred_time=pred_time[0:segment_size]
-            residual_time=do_istft(residual)
-            residual_time=np.array(residual_time)
-            residual_time=residual_time[0:segment_size]
-            if pointer==0:
-                pred_time=pred_time
-                residual_time=residual_time
+            pred = pred[0]
+            residual = segment_TF - pred[0]
+            residual = np.array(residual)
+            pred_time = do_istft(pred[0])
+            pred_time = np.array(pred_time)
+            pred_time = pred_time[0:segment_size]
+            residual_time = do_istft(residual)
+            residual_time = np.array(residual_time)
+            residual_time = residual_time[0:segment_size]
+            if pointer == 0:
+                pred_time = pred_time
+                residual_time = residual_time
             else:
-                pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]),axis=0)
-                residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]),axis=0)
+                pred_time = np.concatenate(
+                    (
+                        np.multiply(pred_time[0 : int(overlapsize)], window_left),
+                        pred_time[int(overlapsize) : int(segment_size)],
+                    ),
+                    axis=0,
+                )
+                residual_time = np.concatenate(
+                    (
+                        np.multiply(residual_time[0 : int(overlapsize)], window_left),
+                        residual_time[int(overlapsize) : int(segment_size)],
+                    ),
+                    axis=0,
+                )
 
-            denoised_data[pointer::]=denoised_data[pointer::]+pred_time[0:lensegment]
-            residual_noise[pointer::]=residual_noise[pointer::]+residual_time[0:lensegment]
+            denoised_data[pointer::] = (
+                denoised_data[pointer::] + pred_time[0:lensegment]
+            )
+            residual_noise[pointer::] = (
+                residual_noise[pointer::] + residual_time[0:lensegment]
+            )
 
-    basename=os.path.splitext(audio)[0]
-    wav_noisy_name=basename+"_noisy_input"+".wav"
+    basename = os.path.splitext(audio)[0]
+    wav_noisy_name = basename + "_noisy_input" + ".wav"
     sf.write(wav_noisy_name, data, 44100)
-    wav_output_name=basename+"_denoised"+".wav"
+    wav_output_name = basename + "_denoised" + ".wav"
     sf.write(wav_output_name, denoised_data, 44100)
-    wav_output_name=basename+"_residual"+".wav"
+    wav_output_name = basename + "_residual" + ".wav"
     sf.write(wav_output_name, residual_noise, 44100)
 
@@ -156,10 +233,3 @@ def main(args):
 
 if __name__ == "__main__":
     main()
-
-
-
-
-
-
-
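Aside (not part of the diff): the chunked inference above stitches 5 s segments with a 2048-sample Hann cross-fade. The fade-out applied to one chunk and the fade-in applied to the next sum to approximately one, so the overlap-add is unity-gain across seams:

```python
import numpy as np

overlapsize = 2048
window = np.hanning(2 * overlapsize)
window_left = window[0:overlapsize]   # fade-in, as in inference.py
window_right = window[overlapsize::]  # fade-out, as in inference.py
gain = window_left + window_right     # combined gain inside the overlap
print(np.max(np.abs(gain - 1.0)))     # ~8e-4, i.e. flat to within ~0.01 dB
```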
@@ -1,5 +1,5 @@
 #!/bin/bash
 
-python inference.py inference.audio=$1
+python inference.py inference.audio="$1"
181 train.py
@ -4,143 +4,179 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run(args):
|
||||
import unet
|
||||
import tensorflow as tf
|
||||
import dataset_loader
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
import soundfile as sf
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
path_experiment=str(args.path_experiment)
|
||||
path_experiment = str(args.path_experiment)
|
||||
|
||||
if not os.path.exists(path_experiment):
|
||||
os.makedirs(path_experiment)
|
||||
|
||||
path_music_train=args.dset.path_music_train
|
||||
path_music_test=args.dset.path_music_test
|
||||
path_music_validation=args.dset.path_music_validation
|
||||
path_music_train = args.dset.path_music_train
|
||||
path_music_validation = args.dset.path_music_validation
|
||||
|
||||
path_noise=args.dset.path_noise
|
||||
path_recordings=args.dset.path_recordings
|
||||
path_noise = args.dset.path_noise
|
||||
|
||||
fs=args.fs
|
||||
overlap=args.overlap
|
||||
seg_len_s_train=args.seg_len_s_train
|
||||
fs = args.fs
|
||||
seg_len_s_train = args.seg_len_s_train
|
||||
|
||||
batch_size=args.batch_size
|
||||
epochs=args.epochs
|
||||
batch_size = args.batch_size
|
||||
epochs = args.epochs
|
||||
|
||||
num_real_test_segments=args.num_real_test_segments
|
||||
buffer_size=args.buffer_size #for shuffle
|
||||
buffer_size = args.buffer_size # for shuffle
|
||||
|
||||
tensorboard_logs=args.tensorboard_logs
|
||||
tensorboard_logs = args.tensorboard_logs
|
||||
|
||||
def do_stft(noisy, clean=None):
|
||||
|
||||
window_fn = tf.signal.hamming_window
|
||||
|
||||
win_size=args.stft.win_size
|
||||
hop_size=args.stft.hop_size
|
||||
win_size = args.stft.win_size
|
||||
hop_size = args.stft.hop_size
|
||||
|
||||
stft_signal_noisy = tf.signal.stft(
|
||||
noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
||||
)
|
||||
stft_noisy_stacked = tf.stack(
|
||||
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||
|
||||
if clean!=None:
|
||||
|
||||
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
|
||||
|
||||
if clean is not None:
|
||||
stft_signal_clean = tf.signal.stft(
|
||||
clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
||||
)
|
||||
stft_clean_stacked = tf.stack(
|
||||
values=[
|
||||
tf.math.real(stft_signal_clean),
|
||||
tf.math.imag(stft_signal_clean),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
return stft_noisy_stacked, stft_clean_stacked
|
||||
else:
|
||||
|
||||
return stft_noisy_stacked
|
||||
|
||||
#Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
||||
# Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
||||
|
||||
dataset_train, dataset_val=dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)
|
||||
dataset_train, dataset_val = dataset_loader.load_data(
|
||||
buffer_size,
|
||||
path_music_train,
|
||||
path_music_validation,
|
||||
path_noise,
|
||||
fs=fs,
|
||||
seg_len_s=seg_len_s_train,
|
||||
)
|
||||
|
||||
dataset_train=dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
dataset_val=dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||
dataset_train = dataset_train.map(
|
||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
||||
)
|
||||
dataset_val = dataset_val.map(
|
||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
||||
)
|
||||
|
||||
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

with strategy.scope():
    #build the model
    # build the model
    unet_model = unet.build_model_denoise(unet_args=args.unet)

    current_lr=args.lr
    current_lr = args.lr
    optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)

    loss=tf.keras.losses.MeanAbsoluteError()
    loss = tf.keras.losses.MeanAbsoluteError()

if args.use_tensorboard:
    log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    train_summary_writer = tf.summary.create_file_writer(log_dir+"/train")
    val_summary_writer = tf.summary.create_file_writer(log_dir+"/validation")
    log_dir = os.path.join(
        tensorboard_logs,
        os.path.basename(path_experiment)
        + "_"
        + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
    train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
    val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")

#path where the checkpoints will be saved
checkpoint_filepath=os.path.join(path_experiment, 'checkpoint')
# path where the checkpoints will be saved
checkpoint_filepath = os.path.join(path_experiment, "checkpoint")

dataset_train=dataset_train.batch(batch_size)
dataset_val=dataset_val.batch(batch_size)
dataset_train = dataset_train.batch(batch_size)
dataset_val = dataset_val.batch(batch_size)

#prefetching the dataset for better performance
dataset_train=dataset_train.prefetch(batch_size*20)
dataset_val=dataset_val.prefetch(batch_size*20)
# prefetching the dataset for better performance
dataset_train = dataset_train.prefetch(batch_size * 20)
dataset_val = dataset_val.prefetch(batch_size * 20)

dataset_train=strategy.experimental_distribute_dataset(dataset_train)
dataset_val=strategy.experimental_distribute_dataset(dataset_val)
dataset_train = strategy.experimental_distribute_dataset(dataset_train)
dataset_val = strategy.experimental_distribute_dataset(dataset_val)

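A note on the distribution pattern above: `experimental_distribute_dataset` splits each global batch across the replicas, and per-replica results are combined later with `strategy.reduce`. A self-contained sketch of the same pattern (illustrative, not repository code):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU
    dataset = tf.data.Dataset.range(8).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    @tf.function
    def step(batch):
        per_replica = strategy.run(
            lambda x: tf.reduce_mean(tf.cast(x, tf.float32)), args=(batch,)
        )
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

    for batch in dist_dataset:
        print(step(batch).numpy())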
iterator = iter(dataset_train)

from trainer import Trainer

trainer=Trainer(unet_model,optimizer,loss,strategy, path_experiment, args)
trainer = Trainer(unet_model, optimizer, loss, strategy, path_experiment, args)

for epoch in range(epochs):
    total_loss=0
    step_loss=0
    for step in tqdm(range(args.steps_per_epoch), desc="Training epoch "+str(epoch)):
        step_loss=trainer.distributed_training_step(iterator.get_next())
        total_loss+=step_loss
    total_loss = 0
    step_loss = 0
    for step in tqdm(
        range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
    ):
        step_loss = trainer.distributed_training_step(iterator.get_next())
        total_loss += step_loss
        with train_summary_writer.as_default():
            tf.summary.scalar('batch_loss', step_loss, step=step)
            tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)
            tf.summary.scalar("batch_loss", step_loss, step=step)
            tf.summary.scalar(
                "batch_mean_absolute_error", trainer.train_mae.result(), step=step
            )

    train_loss=total_loss/args.steps_per_epoch
    train_loss = total_loss / args.steps_per_epoch

    for x in tqdm(dataset_val, desc="Validating epoch "+str(epoch)):
    for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
        trainer.distributed_test_step(x)

    template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
    print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
    template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
    print(
        template.format(
            epoch + 1,
            train_loss,
            trainer.train_mae.result(),
            trainer.val_loss.result(),
            trainer.val_mae.result(),
        )
    )

    with train_summary_writer.as_default():
        tf.summary.scalar('epoch_loss', train_loss, step=epoch)
        tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
        tf.summary.scalar("epoch_loss", train_loss, step=epoch)
        tf.summary.scalar(
            "epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
        )
    with val_summary_writer.as_default():
        tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
        tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)
        tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
        tf.summary.scalar(
            "epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
        )

    trainer.train_mae.reset_states()
    trainer.val_loss.reset_states()
    trainer.val_mae.reset_states()

    if (epoch+1) % 50 == 0:
    if (epoch + 1) % 50 == 0:
        if args.variable_lr:
            current_lr*=1e-1
            trainer.optimizer.lr=current_lr
            current_lr *= 1e-1
            trainer.optimizer.lr = current_lr
    try:
        unet_model.save_weights(checkpoint_filpath)
    except:
        unet_model.save_weights(checkpoint_filepath)
    except Exception:
        pass


def _main(args):
    global __file__

@@ -161,10 +197,3 @@ def main(args):

if __name__ == "__main__":
    main()

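The `variable_lr` branch above is a plain step decay: every 50 epochs the learning rate is divided by 10. As a standalone sketch of the same schedule:

    def step_decay(initial_lr, epoch, drop_every=50, factor=0.1):
        # learning rate after `epoch` epochs, dropped by `factor` every `drop_every` epochs
        return initial_lr * factor ** ((epoch + 1) // drop_every)

    # with initial_lr=1e-4: epochs 0-48 train at 1e-4, epoch 49 onwards at 1e-5, and so on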
77 trainer.py
@@ -1,39 +1,37 @@

import os
import numpy as np
import tensorflow as tf
import soundfile as sf
from tqdm import tqdm
import pandas as pd

class Trainer():
    def __init__(self, model, optimizer,loss, strategy, path_experiment, args):
        self.model=model

class Trainer:
    def __init__(self, model, optimizer, loss, strategy, path_experiment, args):
        self.model = model
        print(self.model.summary())
        self.strategy=strategy
        self.optimizer=optimizer
        self.path_experiment=path_experiment
        self.args=args
        #self.metrics=[]
        self.strategy = strategy
        self.optimizer = optimizer
        self.path_experiment = path_experiment
        self.args = args
        # self.metrics=[]

        with self.strategy.scope():
            #loss_fn=tf.keras.losses.mean_absolute_error
            loss.reduction=tf.keras.losses.Reduction.NONE
            self.loss_object=loss
            self.train_mae_s1=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
            self.train_mae=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
            self.val_mae=tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
            self.val_loss = tf.keras.metrics.Mean(name='test_loss')
            # loss_fn=tf.keras.losses.mean_absolute_error
            loss.reduction = tf.keras.losses.Reduction.NONE
            self.loss_object = loss
            self.train_mae_s1 = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
            self.train_mae = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
            self.val_mae = tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
            self.val_loss = tf.keras.metrics.Mean(name="test_loss")

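A remark on `Reduction.NONE` above: inside a distribution strategy, the loss must not average over the batch implicitly, because each replica only sees its slice of the global batch; the averaging is done explicitly with `tf.reduce_mean` in `train_step` and `strategy.reduce` afterwards. A toy illustration:

    import tensorflow as tf

    mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)
    y_true = tf.constant([[1.0], [2.0]])
    y_pred = tf.constant([[1.5], [1.0]])
    per_example = mae(y_true, y_pred)   # shape (2,): one value per example
    loss = tf.reduce_mean(per_example)  # explicit, replica-local reduction
    print(per_example.numpy(), loss.numpy())  # [0.5 1. ] 0.75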
    def train_step(self,inputs):
        noisy, clean= inputs
    def train_step(self, inputs):
        noisy, clean = inputs

        with tf.GradientTape() as tape:
            logits_2, logits_1 = self.model(
                noisy, training=True
            )  # Logits for this minibatch

            logits_2,logits_1 = self.model(noisy, training=True) # Logits for this minibatch

            loss_value = tf.reduce_mean(self.loss_object(clean, logits_2) + tf.reduce_mean(self.loss_object(clean, logits_1)))
            loss_value = tf.reduce_mean(
                self.loss_object(clean, logits_2)
                + tf.reduce_mean(self.loss_object(clean, logits_1))
            )

        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
@@ -41,26 +39,25 @@ class Trainer():
        self.train_mae_s1.update_state(clean, logits_1)
        return loss_value

    def test_step(self,inputs):

        noisy,clean = inputs
    def test_step(self, inputs):
        noisy, clean = inputs

        predictions_s2, predictions_s1 = self.model(noisy, training=False)
        t_loss = self.loss_object(clean, predictions_s2)+self.loss_object(clean, predictions_s1)
        t_loss = self.loss_object(clean, predictions_s2) + self.loss_object(
            clean, predictions_s1
        )

        self.val_mae.update_state(clean,predictions_s2)
        self.val_mae.update_state(clean, predictions_s2)
        self.val_loss.update_state(t_loss)

    @tf.function()
    def distributed_training_step(self,inputs):
        per_replica_losses=self.strategy.run(self.train_step, args=(inputs,))
        reduced_losses=self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    def distributed_training_step(self, inputs):
        per_replica_losses = self.strategy.run(self.train_step, args=(inputs,))
        reduced_losses = self.strategy.reduce(
            tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None
        )
        return reduced_losses

    @tf.function
    def distributed_test_step(self,inputs):
    def distributed_test_step(self, inputs):
        return self.strategy.run(self.test_step, args=(inputs,))

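Putting the pieces together, the training script drives this class roughly as follows (a sketch with illustrative names; `args`, `path_experiment`, and `dist_dataset` are assumed to exist as in train.py):

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = unet.build_model_denoise(unet_args=args.unet)
        loss = tf.keras.losses.MeanAbsoluteError()
        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    trainer = Trainer(model, optimizer, loss, strategy, path_experiment, args)
    noisy_clean_batch = next(iter(dist_dataset))  # (noisy, clean) spectrogram pair
    print(trainer.distributed_training_step(noisy_clean_batch))  # reduced scalar loss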
564 unet.py
@@ -1,84 +1,93 @@
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m


def build_model_denoise(unet_args=None):
    inputs = Input(shape=(None, None, 2))

    inputs=Input(shape=(None, None,2))
    outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)

    outputs_stage_2,outputs_stage_1=MultiStage_denoise(unet_args=unet_args)(inputs)

    #Encapsulating MultiStage_denoise in a keras.Model object
    model= tf.keras.Model(inputs=inputs,outputs=[outputs_stage_2, outputs_stage_1])
    # Encapsulating MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])

    return model

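As a quick sanity check, the builder can be exercised on a dummy spectrogram (a sketch; `unet_args` must be a config object carrying the fields that `MultiStage_denoise` reads below):

    import tensorflow as tf

    model = build_model_denoise(unet_args=args.unet)  # config object assumed
    dummy = tf.zeros([1, 64, 1025, 2])  # (batch, time frames, freq bins, real/imag)
    out_s2, out_s1 = model(dummy)
    print(out_s2.shape, out_s1.shape)  # both (1, 64, 1025, 2)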
class DenseBlock(layers.Layer):
    '''
    """
    [B, T, F, N] => [B, T, F, N]
    DenseNet Block consisting of "num_layers" densely connected convolutional layers
    '''
    def __init__(self, num_layers, N, ksize,activation):
        '''
    """

    def __init__(self, num_layers, N, ksize, activation):
        """
        num_layers: number of densely connected conv. layers
        N: Number of filters (same in each layer)
        ksize: Kernel size (same in each layer)
        '''
        """
        super(DenseBlock, self).__init__()
        self.activation=activation
        self.activation = activation

        self.paddings_1=get_paddings(ksize)
        self.H=[]
        self.num_layers=num_layers
        self.paddings_1 = get_paddings(ksize)
        self.H = []
        self.num_layers = num_layers

        for i in range(num_layers):
            self.H.append(layers.Conv2D(filters=N,
            self.H.append(
                layers.Conv2D(
                    filters=N,
                    kernel_size=ksize,
                    kernel_initializer=TruncatedNormal(),
                    strides=1,
                    padding='VALID',
                    activation=self.activation))
                    padding="VALID",
                    activation=self.activation,
                )
            )

    def call(self, x):

        x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
        x_ = self.H[0](x_)
        if self.num_layers>1:
        if self.num_layers > 1:
            for h in self.H[1:]:
                x = tf.concat([x_, x], axis=-1)
                x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
                x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
                x_ = h(x_)

        return x_


class FinalBlock(layers.Layer):
    '''
    """
    [B, T, F, N] => [B, T, F, 2]
    Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.

    '''
    """

    def __init__(self):
        super(FinalBlock, self).__init__()
        ksize=(3,3)
        self.paddings_2=get_paddings(ksize)
        self.conv2=layers.Conv2D(filters=2,
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=None)
            padding="VALID",
            activation=None,
        )


    def call(self, inputs ):

        x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        pred=self.conv2(x)
    def call(self, inputs):
        x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        pred = self.conv2(x)

        return pred

class SAM(layers.Layer):
    '''
    """
    [B, T, F, N] => [B, T, F, N] , [B, T, F, N]
    Supervised Attention Module:
    The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
@@ -86,390 +95,424 @@ class SAM(layers.Layer):
    The first stage output is then calculated adding the original input spectrogram to the residual noise.
    The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.

    '''
    """

    def __init__(self, n_feat):
        super(SAM, self).__init__()

        ksize=(3,3)
        self.paddings_1=get_paddings(ksize)
        self.conv1 = layers.Conv2D(filters=n_feat,
        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=None)
        ksize=(3,3)
        self.paddings_2=get_paddings(ksize)
        self.conv2=layers.Conv2D(filters=2,
            padding="VALID",
            activation=None,
        )
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=None)
            padding="VALID",
            activation=None,
        )

        ksize=(3,3)
        self.paddings_3=get_paddings(ksize)
        self.conv3 = layers.Conv2D(filters=n_feat,
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=None)
        self.cropadd=CropAddBlock()
            padding="VALID",
            activation=None,
        )
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
        x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
        x1 = self.conv1(x1)

        x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        x=self.conv2(x)
        x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        x = self.conv2(x)

        #residual prediction
        pred = layers.Add()([x, input_spectrogram]) #features to next stage
        # residual prediction
        pred = layers.Add()([x, input_spectrogram])  # features to next stage

        x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
        M=self.conv3(x3)
        x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
        M = self.conv3(x3)

        M= tf.keras.activations.sigmoid(M)
        x1=layers.Multiply()([x1, M])
        x1 = layers.Add()([x1, inputs]) #features to next stage
        M = tf.keras.activations.sigmoid(M)
        x1 = layers.Multiply()([x1, M])
        x1 = layers.Add()([x1, inputs])  # features to next stage

        return x1, pred

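In equations, SAM computes the stage-1 estimate `pred = conv2(F) + X`, a mask `M = sigmoid(conv3(pred))`, and forwards `Fout = conv1(F) * M + F` to the second stage. A NumPy sketch of just the gating step (for intuition only):

    import numpy as np

    def sam_gate(features, mask_logits):
        # sigmoid mask in (0, 1) scales the features; the residual add keeps
        # the original features flowing even where the mask is near zero
        M = 1.0 / (1.0 + np.exp(-mask_logits))
        return features * M + features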
class AddFreqEncoding(layers.Layer):
    '''
    """
    [B, T, F, 2] => [B, T, F, 12]
    Generates frequency positional embeddings and concatenates them as 10 extra channels
    This function is optimized for F=1025
    '''
    """

    def __init__(self, f_dim):
        super(AddFreqEncoding, self).__init__()
        pi = tf.constant(m.pi)
        pi=tf.cast(pi,'float32')
        self.f_dim=f_dim #f_dim is fixed
        n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
        coss=tf.math.cos(pi*n)
        f_channel = tf.expand_dims(coss, -1) #(1025,1)
        self.fembeddings= f_channel

        for k in range(1,10):
            coss=tf.math.cos(2**k*pi*n)
            f_channel = tf.expand_dims(coss, -1) #(1025,1)
            self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
        pi = tf.cast(pi, "float32")
        self.f_dim = f_dim  # f_dim is fixed
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")
        coss = tf.math.cos(pi * n)
        f_channel = tf.expand_dims(coss, -1)  # (1025,1)
        self.fembeddings = f_channel

        for k in range(1, 10):
            coss = tf.math.cos(2**k * pi * n)
            f_channel = tf.expand_dims(coss, -1)  # (1025,1)
            self.fembeddings = tf.concat(
                [self.fembeddings, f_channel], axis=-1
            )  # (1025,10)

    def call(self, input_tensor):

        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension

        fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
        fembeddings_2 = tf.broadcast_to(
            self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]
        )


        return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
        return tf.concat([input_tensor, fembeddings_2], axis=-1)  # (batch,427,1025,12)

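The ten channels built above are cos(2^k * pi * n) for k = 0..9, with n the bin index normalized to [0, 1], so every frequency bin receives a nearly unique code, much like transformer positional encodings. A compact NumPy equivalent (illustration only):

    import numpy as np

    f_dim = 1025
    n = np.arange(f_dim) / (f_dim - 1)  # normalized bin index
    femb = np.stack([np.cos(2.0**k * np.pi * n) for k in range(10)], axis=-1)
    print(femb.shape)  # (1025, 10), matching self.fembeddings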
def get_paddings(K):
    return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]])
    return tf.constant(
        [
            [0, 0],
            [K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
            [K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
            [0, 0],
        ]
    )

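For example, `get_paddings((3, 3))` gives one pixel of padding on each side of the time and frequency axes, while an even kernel pads asymmetrically, so that a 'VALID' convolution after `tf.pad` preserves the spatial size like 'SAME' would, but with symmetric boundary values:

    print(get_paddings((3, 3)).numpy())  # [[0 0] [1 1] [1 1] [0 0]]
    print(get_paddings((4, 4)).numpy())  # [[0 0] [2 1] [2 1] [0 0]]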
class Decoder(layers.Layer):
    '''
    """
    [B, T, F, N] , skip connections => [B, T, F, N]
    Decoder side of the U-Net subnetwork.
    '''
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Decoder, self).__init__()

        self.Ns=Ns
        self.Ss=Ss
        self.activation=unet_args.activation
        self.depth=unet_args.depth
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth


        ksize=(3,3)
        self.paddings_3=get_paddings(ksize)
        self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv2d_3 = layers.Conv2D(
            filters=self.Ns[self.depth],
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=self.activation)
            padding="VALID",
            activation=self.activation,
        )

        self.cropadd=CropAddBlock()
        self.cropadd = CropAddBlock()

        self.dblocks=[]
        self.dblocks = []
        for i in range(self.depth):
            self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc))
            self.dblocks.append(
                D_Block(
                    layer_idx=i,
                    N=self.Ns[i],
                    S=self.Ss[i],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc,
                )
            )

    def call(self,inputs, contracting_layers):
        x=inputs
        for i in range(self.depth,0,-1):
            x=self.dblocks[i-1](x, contracting_layers[i-1])
    def call(self, inputs, contracting_layers):
        x = inputs
        for i in range(self.depth, 0, -1):
            x = self.dblocks[i - 1](x, contracting_layers[i - 1])
        return x

class Encoder(tf.keras.Model):

    '''
class Encoder(tf.keras.Model):
    """
    [B, T, F, N] => skip connections , [B, T, F, N_4]
    Encoder side of the U-Net subnetwork.
    '''
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Encoder, self).__init__()
        self.Ns=Ns
        self.Ss=Ss
        self.activation=unet_args.activation
        self.depth=unet_args.depth
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        self.contracting_layers = {}

        self.eblocks=[]
        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
            self.eblocks.append(
                E_Block(
                    layer_idx=i,
                    N0=self.Ns[i],
                    N=self.Ns[i + 1],
                    S=self.Ss[i],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc,
                )
            )

        self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x=inputs
        x = inputs
        for i in range(self.depth):
            x, x_contract = self.eblocks[i](x)

            x, x_contract=self.eblocks[i](x)

            self.contracting_layers[i] = x_contract #if remove 0, correct this
        x=self.i_block(x)
            self.contracting_layers[i] = x_contract  # if remove 0, correct this
        x = self.i_block(x)

        return x, self.contracting_layers

class MultiStage_denoise(tf.keras.Model):

class MultiStage_denoise(tf.keras.Model):
    def __init__(self, unet_args=None):
        super(MultiStage_denoise, self).__init__()

        self.activation=unet_args.activation
        self.depth=unet_args.depth
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding=AddFreqEncoding(unet_args.f_dim)
        self.use_sam=unet_args.use_SAM
        self.use_fencoding=unet_args.use_fencoding
        self.num_stages=unet_args.num_stages
        #Encoder
        self.Ns= [32,64,64,128,128,256,512]
        self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)]
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages
        # Encoder
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]

        #initial feature extractor
        ksize=(7,7)
        self.paddings_1=get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
        # initial feature extractor
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(
            filters=self.Ns[0],
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID',
            activation=self.activation)
            padding="VALID",
            activation=self.activation,
        )


        self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)

        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()

        self.finalblock=FinalBlock()
        self.finalblock = FinalBlock()

        if self.num_stages>1:
            self.sam_1=SAM(self.Ns[0])
        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])

            #initial feature extractor
            ksize=(7,7)
            self.paddings_2=get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
            # initial feature extractor
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(
                filters=self.Ns[0],
                kernel_size=ksize,
                kernel_initializer=TruncatedNormal(),
                strides=1,
                padding='VALID',
                activation=self.activation)
                padding="VALID",
                activation=self.activation,
            )


            self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):

        if self.use_fencoding:
            x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
            x_w_freq = self.freq_encoding(inputs)  # None, None, 1025, 12
        else:
            x_w_freq=inputs
            x_w_freq = inputs

        #intitial feature extractor
        x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
        x=self.conv2d_1(x) #None, None, 1025, 32
        # intitial feature extractor
        x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
        x = self.conv2d_1(x)  # None, None, 1025, 32

        x, contracting_layers_s1= self.encoder_s1(x)
        #decoder
        feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(
            x, contracting_layers_s1
        )  # None, None, 1025, 32 features

        if self.num_stages>1:
            #SAM module
            Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
        if self.num_stages > 1:
            # SAM module
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)

            #intitial feature extractor
            x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
            x=self.conv2d_2(x)
            # intitial feature extractor
            x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
            x = self.conv2d_2(x)

            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x,feats_s1], axis=-1)
                x = tf.concat([x, feats_s1], axis=-1)

            x, contracting_layers_s2= self.encoder_s2(x)
            x, contracting_layers_s2 = self.encoder_s2(x)

            feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
            feats_s2 = self.decoder_s2(
                x, contracting_layers_s2
            )  # None, None, 1025, 32 features

            #consider implementing a third stage?
            # consider implementing a third stage?

            pred_stage_2=self.finalblock(feats_s2)
            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1=self.finalblock(feats_s1)
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1

class I_Block(layers.Layer):
    '''
    """
    [B, T, F, N] => [B, T, F, N]
    Intermediate block:
    Basically, a densenet block with a residual connection
    '''
    def __init__(self,N,activation, num_tfc, **kwargs):
    """

    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)

        ksize=(3,3)
        self.tfc=DenseBlock(num_tfc,N,ksize, activation)
        ksize = (3, 3)
        self.tfc = DenseBlock(num_tfc, N, ksize, activation)

        self.conv2d_res= layers.Conv2D(filters=N,
            kernel_size=(1,1),
        self.conv2d_res = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding='VALID')
            padding="VALID",
        )

    def call(self,inputs):
        x=self.tfc(inputs)
    def call(self, inputs):
        x = self.tfc(inputs)

        inputs_proj=self.conv2d_res(inputs)
        return layers.Add()([x,inputs_proj])
        inputs_proj = self.conv2d_res(inputs)
        return layers.Add()([x, inputs_proj])

class E_Block(layers.Layer):

    def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs):
    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx=layer_idx
        self.N0=N0
        self.N=N
        self.S=S
        self.activation=activation
        self.i_block=I_Block(N0,activation,num_tfc)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.i_block = I_Block(N0, activation, num_tfc)

        ksize=(S[0]+2,S[1]+2)
        self.paddings_2=get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(filters=N,
            kernel_size=(S[0]+2,S[1]+2),
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(
            filters=N,
            kernel_size=(S[0] + 2, S[1] + 2),
            kernel_initializer=TruncatedNormal(),
            strides=S,
            padding='VALID',
            activation=self.activation)

            padding="VALID",
            activation=self.activation,
        )

    def call(self, inputs, training=None, **kwargs):
        x=self.i_block(inputs)
        x = self.i_block(inputs)

        x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
        x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
        x_down = self.conv2d_2(x_down)

        return x_down, x


    def get_config(self):
        return dict(layer_idx=self.layer_idx,
        return dict(
            layer_idx=self.layer_idx,
            N=self.N,
            S=self.S,
            **super(E_Block, self).get_config()
            **super(E_Block, self).get_config(),
        )

class D_Block(layers.Layer):

    def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs):
    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx=layer_idx
        self.N=N
        self.S=S
        self.activation=activation
        ksize=(S[0]+2, S[1]+2)
        self.paddings_1=get_paddings(ksize)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_1 = get_paddings(ksize)

        self.tconv_1= layers.Conv2DTranspose(filters=N,
            kernel_size=(S[0]+2, S[1]+2),
        self.tconv_1 = layers.Conv2DTranspose(
            filters=N,
            kernel_size=(S[0] + 2, S[1] + 2),
            kernel_initializer=TruncatedNormal(),
            strides=S,
            activation=self.activation,
            padding='VALID')
            padding="VALID",
        )

        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
        self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")

        self.projection = layers.Conv2D(filters=N,
            kernel_size=(1,1),
        self.projection = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            activation=self.activation,
            padding='VALID')
        self.cropadd=CropAddBlock()
        self.cropconcat=CropConcatBlock()
            padding="VALID",
        )
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()

        self.i_block=I_Block(N,activation,num_tfc)
        self.i_block = I_Block(N, activation, num_tfc)

    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs):
    def call(
        self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
    ):
        x = inputs
        x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
        x = self.tconv_1(inputs)

        x2= self.upsampling(inputs)
        x2 = self.upsampling(inputs)

        if x2.shape[-1]!=x.shape[-1]:
            x2= self.projection(x2)
        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)

        x= self.cropadd(x,x2)
        x = self.cropadd(x, x2)

        x = self.cropconcat(x, bridge)

        x=self.cropconcat(x,bridge)

        x=self.i_block(x)
        x = self.i_block(x)
        return x

    def get_config(self):
        return dict(layer_idx=self.layer_idx,
        return dict(
            layer_idx=self.layer_idx,
            N=self.N,
            S=self.S,
            **super(D_Block, self).get_config()
            **super(D_Block, self).get_config(),
        )

class CropAddBlock(layers.Layer):

    def call(self,down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)


        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        down_layer_cropped = down_layer[:,
            height_diff: (x2_shape[1] + height_diff),
            width_diff: (x2_shape[2] + width_diff),
            :]

        x = layers.Add()([down_layer_cropped, x])
        return x

class CropConcatBlock(layers.Layer):

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)
@@ -477,10 +520,31 @@ class CropConcatBlock(layers.Layer):
        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        down_layer_cropped = down_layer[:,
            height_diff: (x2_shape[1] + height_diff),
            width_diff: (x2_shape[2] + width_diff),
            :]
        down_layer_cropped = down_layer[
            :,
            height_diff : (x2_shape[1] + height_diff),
            width_diff : (x2_shape[2] + width_diff),
            :,
        ]

        x = layers.Add()([down_layer_cropped, x])
        return x


class CropConcatBlock(layers.Layer):
    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)

        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        down_layer_cropped = down_layer[
            :,
            height_diff : (x2_shape[1] + height_diff),
            width_diff : (x2_shape[2] + width_diff),
            :,
        ]

        x = tf.concat([down_layer_cropped, x], axis=-1)
        return x

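To make the cropping arithmetic concrete: with a (1, 10, 10, 8) skip tensor and a (1, 8, 8, 8) upsampled tensor, `height_diff` and `width_diff` are both 1, so a centered 8x8 window is cut from the skip tensor before adding or concatenating (a toy check, not repository code):

    import tensorflow as tf

    down = tf.zeros([1, 10, 10, 8])  # skip connection from the encoder
    up = tf.zeros([1, 8, 8, 8])      # upsampled decoder features
    print(CropConcatBlock()(down, up).shape)  # (1, 8, 8, 16): channels concatenated
    print(CropAddBlock()(down, up).shape)     # (1, 8, 8, 8): element-wise add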