Compare commits
No commits in common. "98702405729670a2ef501bda5945447646938451" and "ff2dff25d5c1db5a2c9870c5419e3bae49f82890" have entirely different histories.
9870240572
...
ff2dff25d5
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
experiments
|
|
||||||
outputs
|
|
||||||
__pycache__
|
|
||||||
15
README.md
15
README.md
@ -13,7 +13,7 @@ width="400px"></p>
|
|||||||
|
|
||||||
Listen to our [audio samples](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/)
|
Listen to our [audio samples](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/)
|
||||||
|
|
||||||
[](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
|
[](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb)
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
You will need at least Python 3.7, and CUDA 10.1 if you want to use a GPU. See `requirements.txt` for the required package versions.
|
You will need at least Python 3.7, and CUDA 10.1 if you want to use a GPU. See `requirements.txt` for the required package versions.
|
||||||
@ -24,10 +24,7 @@ To install the environment through anaconda, follow the instructions:
|
|||||||
conda activate historical_denoiser
|
conda activate historical_denoiser
|
||||||
|
|
||||||
## Denoising Recordings
|
## Denoising Recordings
|
||||||
|
Run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
|
||||||
You can denoise your recordings in the cloud using the Colab notebook. [](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
|
|
||||||
|
|
||||||
Otherwise, run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
|
|
||||||
|
|
||||||
git clone https://github.com/eloimoliner/denoising-historical-recordings.git
|
git clone https://github.com/eloimoliner/denoising-historical-recordings.git
|
||||||
cd denoising-historical-recordings
|
cd denoising-historical-recordings
|
||||||
@ -40,13 +37,7 @@ If the environment is installed correctly, you can denoise an audio file by runn
|
|||||||
|
|
||||||
A ".wav" file with the denoised version, as well as the residual noise and the original signal in "mono", will be generated in the same directory as the input file.
|
A ".wav" file with the denoised version, as well as the residual noise and the original signal in "mono", will be generated in the same directory as the input file.
|
||||||
## Training
|
## Training
|
||||||
To retrain the model, follow the instructions:
|
TODO
|
||||||
|
|
||||||
Download the [Gramophone Noise Dataset](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/media/datasets/Gramophone_Record_Noise_Dataset.zip), or any other dataset containing recording noises.
|
|
||||||
|
|
||||||
Prepare a dataset of clean music (e.g. [MusicNet](https://zenodo.org/record/5120004#.YnN-96IzbmE))
|
|
||||||
|
|
||||||
|
|
||||||
## Remarks
|
## Remarks
|
||||||
|
|
||||||
The trained model is specialized in denoising gramophone recordings, such as the ones included in this collection: https://archive.org/details/georgeblood. It has been shown to be robust to a wide range of different noises, but it may produce some artifacts if you run inference on something completely different.
|
The trained model is specialized in denoising gramophone recordings, such as the ones included in this collection: https://archive.org/details/georgeblood. It has been shown to be robust to a wide range of different noises, but it may produce some artifacts if you run inference on something completely different.
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
"colab_type": "text"
|
"colab_type": "text"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -40,8 +40,7 @@
|
|||||||
"* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
|
"* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
|
||||||
"* Press ▶️ on the left of each of the cells\n",
|
"* Press ▶️ on the left of each of the cells\n",
|
||||||
"* View the code: Double-click any of the cells\n",
|
"* View the code: Double-click any of the cells\n",
|
||||||
"* Hide the code: Double click the right side of the cell\n",
|
"* Hide the code: Double click the right side of the cell\n"
|
||||||
"* For some reason, this notebook does not work in Firefox, so please use another browser.\n"
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "8UON6ncSApA9"
|
"id": "8UON6ncSApA9"
|
||||||
@ -208,7 +207,7 @@
|
|||||||
"id": "TQBDTmO4mUBx"
|
"id": "TQBDTmO4mUBx"
|
||||||
},
|
},
|
||||||
"id": "TQBDTmO4mUBx",
|
"id": "TQBDTmO4mUBx",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -244,7 +243,7 @@
|
|||||||
"outputId": "2d05860c-536d-45f8-92b4-d2ba6f5a54c5"
|
"outputId": "2d05860c-536d-45f8-92b4-d2ba6f5a54c5"
|
||||||
},
|
},
|
||||||
"id": "50Kmdy6AtbhW",
|
"id": "50Kmdy6AtbhW",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "display_data",
|
"output_type": "display_data",
|
||||||
@ -297,7 +296,7 @@
|
|||||||
"outputId": "173f5355-2939-41fe-c702-591aa752fc7e"
|
"outputId": "173f5355-2939-41fe-c702-591aa752fc7e"
|
||||||
},
|
},
|
||||||
"id": "0po6zpvrylc2",
|
"id": "0po6zpvrylc2",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
@ -334,7 +333,7 @@
|
|||||||
"outputId": "54588c26-0b3c-42bf-aca2-8316ab54603f"
|
"outputId": "54588c26-0b3c-42bf-aca2-8316ab54603f"
|
||||||
},
|
},
|
||||||
"id": "3tEshWBezYvf",
|
"id": "3tEshWBezYvf",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "display_data",
|
"output_type": "display_data",
|
||||||
|
|||||||
@ -4,12 +4,14 @@ import tensorflow as tf
|
|||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy.fft import fft, ifft
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
import math
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import scipy as sp
|
||||||
import glob
|
import glob
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
#generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis.
|
#generator function. It reads the csv file with pandas and loads the largest audio segments from each recording. If extend=False, it will only read the segments with length>length_seg, trim them and yield them with no further processing. Otherwise, if the segment length is inferior, it will extend the length using concatenative synthesis.
|
||||||
def __noise_sample_generator(info_file,fs, length_seq, split):
|
def __noise_sample_generator(info_file,fs, length_seq, split):
|
||||||
head=os.path.split(info_file)[0]
|
head=os.path.split(info_file)[0]
|
||||||
@ -24,25 +26,14 @@ def __noise_sample_generator(info_file, fs, length_seq, split):
|
|||||||
for i in r:
|
for i in r:
|
||||||
segments=ast.literal_eval(load_data_split.loc[i,"segments"])
|
segments=ast.literal_eval(load_data_split.loc[i,"segments"])
|
||||||
if split=="test":
|
if split=="test":
|
||||||
loaded_data, Fs = sf.read(
|
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],load_data_split["largest_segment"].loc[i]))
|
||||||
os.path.join(
|
|
||||||
head,
|
|
||||||
load_data_split["recording"].loc[i],
|
|
||||||
load_data_split["largest_segment"].loc[i],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
num=np.random.randint(0,len(segments))
|
num=np.random.randint(0,len(segments))
|
||||||
loaded_data, Fs = sf.read(
|
loaded_data, Fs=sf.read(os.path.join(head,load_data_split["recording"].loc[i],segments[num]))
|
||||||
os.path.join(
|
assert(fs==Fs, "wrong sampling rate")
|
||||||
head, load_data_split["recording"].loc[i], segments[num]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert fs == Fs, "wrong sampling rate"
|
|
||||||
|
|
||||||
yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
|
yield __extend_sample_by_repeating(loaded_data,fs,length_seq)
|
||||||
|
|
||||||
|
|
||||||
def __extend_sample_by_repeating(data, fs,seq_len):
|
def __extend_sample_by_repeating(data, fs,seq_len):
|
||||||
rpm=78
|
rpm=78
|
||||||
target_samp=seq_len
|
target_samp=seq_len
|
||||||
@ -66,11 +57,12 @@ def __extend_sample_by_repeating(data, fs, seq_len):
|
|||||||
|
|
||||||
overhead=len(data)%period_sam
|
overhead=len(data)%period_sam
|
||||||
|
|
||||||
if overhead > bls:
|
if(overhead>bls):
|
||||||
complete_periods=(len(data)//period_sam)*period_sam
|
complete_periods=(len(data)//period_sam)*period_sam
|
||||||
else:
|
else:
|
||||||
complete_periods=(len(data)//period_sam -1)*period_sam
|
complete_periods=(len(data)//period_sam -1)*period_sam
|
||||||
|
|
||||||
|
|
||||||
a=np.multiply(data[0:bls], window_left)
|
a=np.multiply(data[0:bls], window_left)
|
||||||
b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
|
b=np.multiply(data[complete_periods:complete_periods+bls], window_right)
|
||||||
c_1=np.concatenate((data[0:complete_periods,:],b))
|
c_1=np.concatenate((data[0:complete_periods,:],b))
|
||||||
@ -79,9 +71,10 @@ def __extend_sample_by_repeating(data, fs, seq_len):
|
|||||||
|
|
||||||
large_data[0:complete_periods+bls,:]=c_1
|
large_data[0:complete_periods+bls,:]=c_1
|
||||||
|
|
||||||
|
|
||||||
pointer=complete_periods
|
pointer=complete_periods
|
||||||
not_finished=True
|
not_finished=True
|
||||||
while not_finished:
|
while (not_finished):
|
||||||
if target_samp>pointer+complete_periods+bls:
|
if target_samp>pointer+complete_periods+bls:
|
||||||
large_data[pointer:pointer+complete_periods+bls] +=c_2
|
large_data[pointer:pointer+complete_periods+bls] +=c_2
|
||||||
pointer+=complete_periods
|
pointer+=complete_periods
|
||||||
@ -93,9 +86,8 @@ def __extend_sample_by_repeating(data, fs, seq_len):
|
|||||||
return large_data
|
return large_data
|
||||||
|
|
||||||
|
|
||||||
def generate_real_recordings_data(
|
def generate_real_recordings_data(path_recordings, fs=44100, seg_len_s=15, stereo=False):
|
||||||
path_recordings, fs=44100, seg_len_s=15, stereo=False
|
|
||||||
):
|
|
||||||
records_info=os.path.join(path_recordings,"audio_files.txt")
|
records_info=os.path.join(path_recordings,"audio_files.txt")
|
||||||
num_lines = sum(1 for line in open(records_info))
|
num_lines = sum(1 for line in open(records_info))
|
||||||
f = open(records_info,"r")
|
f = open(records_info,"r")
|
||||||
@ -120,18 +112,8 @@ def generate_real_recordings_data(
|
|||||||
|
|
||||||
return records
|
return records
|
||||||
|
|
||||||
|
def generate_paired_data_test_formal(path_pianos, path_noises, noise_amount="low_snr",num_samples=-1, fs=44100, seg_len_s=5 , extend=True, stereo=False, prenoise=False):
|
||||||
|
|
||||||
def generate_paired_data_test_formal(
|
|
||||||
path_pianos,
|
|
||||||
path_noises,
|
|
||||||
noise_amount="low_snr",
|
|
||||||
num_samples=-1,
|
|
||||||
fs=44100,
|
|
||||||
seg_len_s=5,
|
|
||||||
extend=True,
|
|
||||||
stereo=False,
|
|
||||||
prenoise=False,
|
|
||||||
):
|
|
||||||
print(num_samples)
|
print(num_samples)
|
||||||
segments_clean=[]
|
segments_clean=[]
|
||||||
segments_noisy=[]
|
segments_noisy=[]
|
||||||
@ -152,13 +134,9 @@ def generate_paired_data_test_formal(
|
|||||||
train_samples=sorted(train_samples)
|
train_samples=sorted(train_samples)
|
||||||
|
|
||||||
if prenoise:
|
if prenoise:
|
||||||
noise_generator = __noise_sample_generator(
|
noise_generator=__noise_sample_generator(noises_info,fs, seg_len+fs, extend, "test") #Adds 1s of silence add the begiing, longer noise
|
||||||
noises_info, fs, seg_len + fs, extend, "test"
|
|
||||||
) # Adds 1s of silence add the begiing, longer noise
|
|
||||||
else:
|
else:
|
||||||
noise_generator = __noise_sample_generator(
|
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, extend, "test") #this will take care of everything
|
||||||
noises_info, fs, seg_len, extend, "test"
|
|
||||||
) # this will take care of everything
|
|
||||||
#load data clean files
|
#load data clean files
|
||||||
for file in tqdm(train_samples): #add [1:5] for testing
|
for file in tqdm(train_samples): #add [1:5] for testing
|
||||||
data_clean, samplerate = sf.read(file)
|
data_clean, samplerate = sf.read(file)
|
||||||
@ -181,22 +159,19 @@ def generate_paired_data_test_formal(
|
|||||||
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
||||||
print(num_frames)
|
print(num_frames)
|
||||||
if num_frames==0:
|
if num_frames==0:
|
||||||
data_clean = np.concatenate(
|
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
num_frames=1
|
num_frames=1
|
||||||
|
|
||||||
data_not_finished=True
|
data_not_finished=True
|
||||||
pointer=0
|
pointer=0
|
||||||
while data_not_finished:
|
while(data_not_finished):
|
||||||
if i>=num_samples:
|
if i>=num_samples:
|
||||||
break
|
break
|
||||||
segment=data_clean[pointer:pointer+seg_len]
|
segment=data_clean[pointer:pointer+seg_len]
|
||||||
pointer=pointer+hop_size
|
pointer=pointer+hop_size
|
||||||
if pointer+seg_len>len(data_clean):
|
if pointer+seg_len>len(data_clean):
|
||||||
data_not_finished=False
|
data_not_finished=False
|
||||||
segment = segment.astype("float32")
|
segment=segment.astype('float32')
|
||||||
|
|
||||||
#SNRs=np.random.uniform(2,20)
|
#SNRs=np.random.uniform(2,20)
|
||||||
snr=SNRs[i]
|
snr=SNRs[i]
|
||||||
@ -220,26 +195,23 @@ def generate_paired_data_test_formal(
|
|||||||
|
|
||||||
#sum both signals according to snr
|
#sum both signals according to snr
|
||||||
if prenoise:
|
if prenoise:
|
||||||
segment = np.concatenate(
|
segment=np.concatenate((np.zeros(shape=(fs,)),segment),axis=0) #add one second of silence
|
||||||
(np.zeros(shape=(fs,)), segment), axis=0
|
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||||
) # add one second of silence
|
|
||||||
summed = (
|
|
||||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
|
||||||
) # not sure if this is correct, maybe revisit later!!
|
|
||||||
|
|
||||||
summed = summed.astype("float32")
|
summed=summed.astype('float32')
|
||||||
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||||
|
|
||||||
|
|
||||||
summed=10.0**(scale/10.0) *summed
|
summed=10.0**(scale/10.0) *summed
|
||||||
segment=10.0**(scale/10.0) *segment
|
segment=10.0**(scale/10.0) *segment
|
||||||
segments_noisy.append(summed.astype("float32"))
|
segments_noisy.append(summed.astype('float32'))
|
||||||
segments_clean.append(segment.astype("float32"))
|
segments_clean.append(segment.astype('float32'))
|
||||||
i=i+1
|
i=i+1
|
||||||
|
|
||||||
return segments_noisy, segments_clean
|
return segments_noisy, segments_clean
|
||||||
|
|
||||||
|
|
||||||
def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
|
def generate_test_data(path_music, path_noises,num_samples=-1, fs=44100, seg_len_s=5):
|
||||||
|
|
||||||
segments_clean=[]
|
segments_clean=[]
|
||||||
segments_noisy=[]
|
segments_noisy=[]
|
||||||
seg_len=fs*seg_len_s
|
seg_len=fs*seg_len_s
|
||||||
@ -250,10 +222,9 @@ def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_le
|
|||||||
train_samples=glob.glob(os.path.join(path,"*.wav"))
|
train_samples=glob.glob(os.path.join(path,"*.wav"))
|
||||||
train_samples=sorted(train_samples)
|
train_samples=sorted(train_samples)
|
||||||
|
|
||||||
noise_generator = __noise_sample_generator(
|
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, "test") #this will take care of everything
|
||||||
noises_info, fs, seg_len, "test"
|
|
||||||
) # this will take care of everything
|
|
||||||
#load data clean files
|
#load data clean files
|
||||||
|
jj=0
|
||||||
for file in tqdm(train_samples): #add [1:5] for testing
|
for file in tqdm(train_samples): #add [1:5] for testing
|
||||||
data_clean, samplerate = sf.read(file)
|
data_clean, samplerate = sf.read(file)
|
||||||
if samplerate!=fs:
|
if samplerate!=fs:
|
||||||
@ -272,18 +243,13 @@ def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_le
|
|||||||
|
|
||||||
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
num_frames=np.floor(len(data_clean)/hop_size - seg_len/hop_size +1)
|
||||||
if num_frames==0:
|
if num_frames==0:
|
||||||
data_clean = np.concatenate(
|
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
num_frames=1
|
num_frames=1
|
||||||
|
|
||||||
pointer=0
|
pointer=0
|
||||||
segment=data_clean[pointer:pointer+(seg_len-2*fs)]
|
segment=data_clean[pointer:pointer+(seg_len-2*fs)]
|
||||||
segment = segment.astype("float32")
|
segment=segment.astype('float32')
|
||||||
segment = np.concatenate(
|
segment=np.concatenate(( np.zeros(shape=(2*fs,)), segment), axis=0) #I hope its ok
|
||||||
(np.zeros(shape=(2 * fs,)), segment), axis=0
|
|
||||||
) # I hope its ok
|
|
||||||
#segments_clean.append(segment)
|
#segments_clean.append(segment)
|
||||||
|
|
||||||
for snr in SNRs:
|
for snr in SNRs:
|
||||||
@ -303,21 +269,17 @@ def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_le
|
|||||||
snr = 10.0**(snr/10.0)
|
snr = 10.0**(snr/10.0)
|
||||||
|
|
||||||
#sum both signals according to snr
|
#sum both signals according to snr
|
||||||
summed = (
|
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
summed=summed.astype('float32')
|
||||||
) # not sure if this is correct, maybe revisit later!!
|
|
||||||
summed = summed.astype("float32")
|
|
||||||
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
#yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||||
|
|
||||||
segments_noisy.append(summed.astype("float32"))
|
segments_noisy.append(summed.astype('float32'))
|
||||||
segments_clean.append(segment.astype("float32"))
|
segments_clean.append(segment.astype('float32'))
|
||||||
|
|
||||||
return segments_noisy, segments_clean
|
return segments_noisy, segments_clean
|
||||||
|
|
||||||
|
def generate_val_data(path_music, path_noises,split,num_samples=-1, fs=44100, seg_len_s=5):
|
||||||
|
|
||||||
def generate_val_data(
|
|
||||||
path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5
|
|
||||||
):
|
|
||||||
val_samples=[]
|
val_samples=[]
|
||||||
for path in path_music:
|
for path in path_music:
|
||||||
val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
|
val_samples.extend(glob.glob(os.path.join(path,"*.wav")))
|
||||||
@ -342,6 +304,7 @@ def generate_val_data(
|
|||||||
seg_len=fs*seg_len_s
|
seg_len=fs*seg_len_s
|
||||||
segments_clean=[]
|
segments_clean=[]
|
||||||
for file in tqdm(data_clean_loaded):
|
for file in tqdm(data_clean_loaded):
|
||||||
|
|
||||||
#framify arguments: seg_len, hop_size
|
#framify arguments: seg_len, hop_size
|
||||||
hop_size=int(seg_len)# no overlap
|
hop_size=int(seg_len)# no overlap
|
||||||
|
|
||||||
@ -350,7 +313,7 @@ def generate_val_data(
|
|||||||
for i in range(0,int(num_frames)):
|
for i in range(0,int(num_frames)):
|
||||||
segment=file[pointer:pointer+seg_len]
|
segment=file[pointer:pointer+seg_len]
|
||||||
pointer=pointer+hop_size
|
pointer=pointer+hop_size
|
||||||
segment = segment.astype("float32")
|
segment=segment.astype('float32')
|
||||||
segments_clean.append(segment)
|
segments_clean.append(segment)
|
||||||
|
|
||||||
del data_clean_loaded
|
del data_clean_loaded
|
||||||
@ -360,9 +323,8 @@ def generate_val_data(
|
|||||||
#noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
|
#noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
|
||||||
noises_info=os.path.join(path_noises,"info.csv")
|
noises_info=os.path.join(path_noises,"info.csv")
|
||||||
|
|
||||||
noise_generator = __noise_sample_generator(
|
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split) #this will take care of everything
|
||||||
noises_info, fs, seg_len, split
|
|
||||||
) # this will take care of everything
|
|
||||||
|
|
||||||
#generate noisy segments
|
#generate noisy segments
|
||||||
#load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
|
#load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
|
||||||
@ -383,6 +345,7 @@ def generate_val_data(
|
|||||||
data_clean=segments_clean[i]
|
data_clean=segments_clean[i]
|
||||||
#configure sizes
|
#configure sizes
|
||||||
|
|
||||||
|
|
||||||
#estimate clean signal power
|
#estimate clean signal power
|
||||||
power_clean=np.var(data_clean)
|
power_clean=np.var(data_clean)
|
||||||
#estimate noise power
|
#estimate noise power
|
||||||
@ -391,37 +354,33 @@ def generate_val_data(
|
|||||||
snr = 10.0**(SNRs[i]/10.0)
|
snr = 10.0**(SNRs[i]/10.0)
|
||||||
|
|
||||||
#sum both signals according to snr
|
#sum both signals according to snr
|
||||||
summed = (
|
summed=data_clean+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||||
data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
|
||||||
) # not sure if this is correct, maybe revisit later!!
|
|
||||||
#the rest is normal
|
#the rest is normal
|
||||||
|
|
||||||
summed=10.0**(scales[i]/10.0) *summed
|
summed=10.0**(scales[i]/10.0) *summed
|
||||||
segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i]
|
segments_clean[i]=10.0**(scales[i]/10.0) *segments_clean[i]
|
||||||
|
|
||||||
segments_noisy.append(summed.astype("float32"))
|
segments_noisy.append(summed.astype('float32'))
|
||||||
|
|
||||||
return segments_noisy, segments_clean
|
return segments_noisy, segments_clean
|
||||||
|
|
||||||
|
|
||||||
def generator_train(
|
|
||||||
path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False
|
def generator_train(path_music, path_noises,split, fs=44100, seg_len_s=5, extend=True, stereo=False):
|
||||||
):
|
|
||||||
train_samples=[]
|
train_samples=[]
|
||||||
for path in path_music:
|
for path in path_music:
|
||||||
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav")))
|
train_samples.extend(glob.glob(os.path.join(path.decode("utf-8") ,"*.wav")))
|
||||||
|
|
||||||
seg_len=fs*seg_len_s
|
seg_len=fs*seg_len_s
|
||||||
noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
|
noises_info=os.path.join(path_noises.decode("utf-8"),"info.csv")
|
||||||
noise_generator = __noise_sample_generator(
|
noise_generator=__noise_sample_generator(noises_info,fs, seg_len, split.decode("utf-8")) #this will take care of everything
|
||||||
noises_info, fs, seg_len, split.decode("utf-8")
|
|
||||||
) # this will take care of everything
|
|
||||||
#load data clean files
|
#load data clean files
|
||||||
while True:
|
while True:
|
||||||
random.shuffle(train_samples)
|
random.shuffle(train_samples)
|
||||||
for file in train_samples:
|
for file in train_samples:
|
||||||
data, samplerate = sf.read(file)
|
data, samplerate = sf.read(file)
|
||||||
assert samplerate == fs, "wrong sampling rate"
|
assert(samplerate==fs, "wrong sampling rate")
|
||||||
data_clean=data
|
data_clean=data
|
||||||
#Stereo to mono
|
#Stereo to mono
|
||||||
if len(data.shape)>1 :
|
if len(data.shape)>1 :
|
||||||
@ -437,33 +396,27 @@ def generator_train(
|
|||||||
|
|
||||||
num_frames=np.floor(len(data_clean)/seg_len)
|
num_frames=np.floor(len(data_clean)/seg_len)
|
||||||
if num_frames==0:
|
if num_frames==0:
|
||||||
data_clean = np.concatenate(
|
data_clean=np.concatenate((data_clean, np.zeros(shape=(int(2*seg_len-len(data_clean)),))), axis=0)
|
||||||
(data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
num_frames=1
|
num_frames=1
|
||||||
pointer=0
|
pointer=0
|
||||||
data_clean = np.roll(
|
data_clean=np.roll(data_clean, np.random.randint(0,seg_len)) #if only one frame, roll it for augmentation
|
||||||
data_clean, np.random.randint(0, seg_len)
|
|
||||||
) # if only one frame, roll it for augmentation
|
|
||||||
elif num_frames>1:
|
elif num_frames>1:
|
||||||
pointer = np.random.randint(
|
pointer=np.random.randint(0,hop_size) #initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
|
||||||
0, hop_size
|
|
||||||
) # initial shifting, graeat for augmentation, better than overlap as we get different frames at each "while" iteration
|
|
||||||
else:
|
else:
|
||||||
pointer=0
|
pointer=0
|
||||||
|
|
||||||
data_not_finished=True
|
data_not_finished=True
|
||||||
while data_not_finished:
|
while(data_not_finished):
|
||||||
segment=data_clean[pointer:pointer+seg_len]
|
segment=data_clean[pointer:pointer+seg_len]
|
||||||
pointer=pointer+hop_size
|
pointer=pointer+hop_size
|
||||||
if pointer+seg_len>len(data_clean):
|
if pointer+seg_len>len(data_clean):
|
||||||
data_not_finished=False
|
data_not_finished=False
|
||||||
segment = segment.astype("float32")
|
segment=segment.astype('float32')
|
||||||
|
|
||||||
SNRs=np.random.uniform(2,20)
|
SNRs=np.random.uniform(2,20)
|
||||||
scale=np.random.uniform(-6,4)
|
scale=np.random.uniform(-6,4)
|
||||||
|
|
||||||
|
|
||||||
#load noise signal
|
#load noise signal
|
||||||
data_noise= next(noise_generator)
|
data_noise= next(noise_generator)
|
||||||
data_noise=np.mean(data_noise,axis=1)
|
data_noise=np.mean(data_noise,axis=1)
|
||||||
@ -474,13 +427,9 @@ def generator_train(
|
|||||||
#configure sizes
|
#configure sizes
|
||||||
if stereo:
|
if stereo:
|
||||||
#estimate clean signal power
|
#estimate clean signal power
|
||||||
power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var(
|
power_clean=0.5*np.var(segment[:,0])+0.5*np.var(segment[:,1])
|
||||||
segment[:, 1]
|
|
||||||
)
|
|
||||||
#estimate noise power
|
#estimate noise power
|
||||||
power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var(
|
power_noise=0.5*np.var(new_noise[:,0])+0.5*np.var(new_noise[:,1])
|
||||||
new_noise[:, 1]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
#estimate clean signal power
|
#estimate clean signal power
|
||||||
power_clean=np.var(segment)
|
power_clean=np.var(segment)
|
||||||
@ -489,63 +438,39 @@ def generator_train(
|
|||||||
|
|
||||||
snr = 10.0**(SNRs/10.0)
|
snr = 10.0**(SNRs/10.0)
|
||||||
|
|
||||||
|
|
||||||
#sum both signals according to snr
|
#sum both signals according to snr
|
||||||
summed = (
|
summed=segment+np.sqrt(power_clean/(snr*power_noise))*new_noise #not sure if this is correct, maybe revisit later!!
|
||||||
segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
|
|
||||||
) # not sure if this is correct, maybe revisit later!!
|
|
||||||
summed=10.0**(scale/10.0) *summed
|
summed=10.0**(scale/10.0) *summed
|
||||||
segment=10.0**(scale/10.0) *segment
|
segment=10.0**(scale/10.0) *segment
|
||||||
|
|
||||||
summed = summed.astype("float32")
|
summed=summed.astype('float32')
|
||||||
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
|
||||||
|
|
||||||
|
def load_data(buffer_size, path_music_train, path_music_val, path_noises, fs=44100, seg_len_s=5, extend=True, stereo=False) :
|
||||||
def load_data(
|
|
||||||
buffer_size,
|
|
||||||
path_music_train,
|
|
||||||
path_music_val,
|
|
||||||
path_noises,
|
|
||||||
fs=44100,
|
|
||||||
seg_len_s=5,
|
|
||||||
extend=True,
|
|
||||||
stereo=False,
|
|
||||||
):
|
|
||||||
print("Generating train dataset")
|
print("Generating train dataset")
|
||||||
trainshape=int(fs*seg_len_s)
|
trainshape=int(fs*seg_len_s)
|
||||||
|
|
||||||
dataset_train = tf.data.Dataset.from_generator(
|
dataset_train = tf.data.Dataset.from_generator(generator_train,args=(path_music_train, path_noises,"train", fs, seg_len_s, extend, stereo), output_shapes=(tf.TensorShape((trainshape,)),tf.TensorShape((trainshape,))), output_types=(tf.float32, tf.float32) )
|
||||||
generator_train,
|
|
||||||
args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo),
|
|
||||||
output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))),
|
|
||||||
output_types=(tf.float32, tf.float32),
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Generating validation dataset")
|
print("Generating validation dataset")
|
||||||
segments_noisy, segments_clean = generate_val_data(
|
segments_noisy, segments_clean=generate_val_data(path_music_val, path_noises,"validation",fs=fs, seg_len_s=seg_len_s)
|
||||||
path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
dataset_val=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||||
|
|
||||||
return dataset_train.shuffle(buffer_size), dataset_val
|
return dataset_train.shuffle(buffer_size), dataset_val
|
||||||
|
|
||||||
|
|
||||||
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
|
def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
|
||||||
print("Generating test dataset")
|
print("Generating test dataset")
|
||||||
segments_noisy, segments_clean = generate_test_data(
|
segments_noisy, segments_clean=generate_test_data(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||||
path_pianos_test, path_noises, extend=True, **kwargs
|
|
||||||
)
|
|
||||||
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||||
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
|
#dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
|
||||||
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||||
return dataset_test
|
return dataset_test
|
||||||
|
|
||||||
|
|
||||||
def load_data_formal( path_pianos_test, path_noises, **kwargs) :
|
def load_data_formal( path_pianos_test, path_noises, **kwargs) :
|
||||||
print("Generating test dataset")
|
print("Generating test dataset")
|
||||||
segments_noisy, segments_clean = generate_paired_data_test_formal(
|
segments_noisy, segments_clean=generate_paired_data_test_formal(path_pianos_test, path_noises, extend=True, **kwargs)
|
||||||
path_pianos_test, path_noises, extend=True, **kwargs
|
|
||||||
)
|
|
||||||
print("segments::")
|
print("segments::")
|
||||||
print(len(segments_noisy))
|
print(len(segments_noisy))
|
||||||
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
|
||||||
@ -553,7 +478,6 @@ def load_data_formal(path_pianos_test, path_noises, **kwargs):
|
|||||||
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
#train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
|
||||||
return dataset_test
|
return dataset_test
|
||||||
|
|
||||||
|
|
||||||
def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
|
def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
|
||||||
print("Generating real test dataset")
|
print("Generating real test dataset")
|
||||||
|
|
||||||
|
|||||||
130
inference.py
130
inference.py
@ -4,7 +4,6 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
import unet
|
import unet
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -17,45 +16,33 @@ def run(args):
|
|||||||
|
|
||||||
unet_model = unet.build_model_denoise(unet_args=args.unet)
|
unet_model = unet.build_model_denoise(unet_args=args.unet)
|
||||||
|
|
||||||
ckpt = os.path.join(
|
ckpt=os.path.join(os.path.dirname(os.path.abspath(__file__)),path_experiment, 'checkpoint')
|
||||||
os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
|
|
||||||
)
|
|
||||||
unet_model.load_weights(ckpt)
|
unet_model.load_weights(ckpt)
|
||||||
|
|
||||||
def do_stft(noisy):
|
def do_stft(noisy):
|
||||||
|
|
||||||
window_fn = tf.signal.hamming_window
|
window_fn = tf.signal.hamming_window
|
||||||
|
|
||||||
win_size=args.stft.win_size
|
win_size=args.stft.win_size
|
||||||
hop_size=args.stft.hop_size
|
hop_size=args.stft.hop_size
|
||||||
|
|
||||||
stft_signal_noisy = tf.signal.stft(
|
|
||||||
noisy,
|
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
|
||||||
frame_length=win_size,
|
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||||
window_fn=window_fn,
|
|
||||||
frame_step=hop_size,
|
|
||||||
pad_end=True,
|
|
||||||
)
|
|
||||||
stft_noisy_stacked = tf.stack(
|
|
||||||
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
return stft_noisy_stacked
|
return stft_noisy_stacked
|
||||||
|
|
||||||
def do_istft(data):
|
def do_istft(data):
|
||||||
|
|
||||||
window_fn = tf.signal.hamming_window
|
window_fn = tf.signal.hamming_window
|
||||||
|
|
||||||
win_size=args.stft.win_size
|
win_size=args.stft.win_size
|
||||||
hop_size=args.stft.hop_size
|
hop_size=args.stft.hop_size
|
||||||
|
|
||||||
inv_window_fn = tf.signal.inverse_stft_window_fn(
|
inv_window_fn=tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
|
||||||
hop_size, forward_window_fn=window_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_cpx=data[...,0] + 1j * data[...,1]
|
pred_cpx=data[...,0] + 1j * data[...,1]
|
||||||
pred_time = tf.signal.inverse_stft(
|
pred_time=tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
|
||||||
pred_cpx, win_size, hop_size, window_fn=inv_window_fn
|
|
||||||
)
|
|
||||||
return pred_time
|
return pred_time
|
||||||
|
|
||||||
audio=str(args.inference.audio)
|
audio=str(args.inference.audio)
|
||||||
@ -70,6 +57,8 @@ def run(args):
|
|||||||
|
|
||||||
data=scipy.signal.resample(data, int((44100 / samplerate )*len(data))+1)
|
data=scipy.signal.resample(data, int((44100 / samplerate )*len(data))+1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
segment_size=44100*5 #20s segments
|
segment_size=44100*5 #20s segments
|
||||||
|
|
||||||
length_data=len(data)
|
length_data=len(data)
|
||||||
@ -77,6 +66,7 @@ def run(args):
|
|||||||
window=np.hanning(2*overlapsize)
|
window=np.hanning(2*overlapsize)
|
||||||
window_right=window[overlapsize::]
|
window_right=window[overlapsize::]
|
||||||
window_left=window[0:overlapsize]
|
window_left=window[0:overlapsize]
|
||||||
|
audio_finished=False
|
||||||
pointer=0
|
pointer=0
|
||||||
denoised_data=np.zeros(shape=(len(data),))
|
denoised_data=np.zeros(shape=(len(data),))
|
||||||
residual_noise=np.zeros(shape=(len(data),))
|
residual_noise=np.zeros(shape=(len(data),))
|
||||||
@ -97,72 +87,21 @@ def run(args):
|
|||||||
residual_time=np.array(residual_time)
|
residual_time=np.array(residual_time)
|
||||||
|
|
||||||
if pointer==0:
|
if pointer==0:
|
||||||
pred_time = np.concatenate(
|
pred_time=np.concatenate((pred_time[0:int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||||
(
|
residual_time=np.concatenate((residual_time[0:int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):segment_size],window_right)), axis=0)
|
||||||
pred_time[0 : int(segment_size - overlapsize)],
|
|
||||||
np.multiply(
|
|
||||||
pred_time[int(segment_size - overlapsize) : segment_size],
|
|
||||||
window_right,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
residual_time = np.concatenate(
|
|
||||||
(
|
|
||||||
residual_time[0 : int(segment_size - overlapsize)],
|
|
||||||
np.multiply(
|
|
||||||
residual_time[
|
|
||||||
int(segment_size - overlapsize) : segment_size
|
|
||||||
],
|
|
||||||
window_right,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
pred_time = np.concatenate(
|
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(pred_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||||
(
|
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size-overlapsize)], np.multiply(residual_time[int(segment_size-overlapsize):int(segment_size)],window_right)), axis=0)
|
||||||
np.multiply(pred_time[0 : int(overlapsize)], window_left),
|
|
||||||
pred_time[int(overlapsize) : int(segment_size - overlapsize)],
|
|
||||||
np.multiply(
|
|
||||||
pred_time[
|
|
||||||
int(segment_size - overlapsize) : int(segment_size)
|
|
||||||
],
|
|
||||||
window_right,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
residual_time = np.concatenate(
|
|
||||||
(
|
|
||||||
np.multiply(residual_time[0 : int(overlapsize)], window_left),
|
|
||||||
residual_time[
|
|
||||||
int(overlapsize) : int(segment_size - overlapsize)
|
|
||||||
],
|
|
||||||
np.multiply(
|
|
||||||
residual_time[
|
|
||||||
int(segment_size - overlapsize) : int(segment_size)
|
|
||||||
],
|
|
||||||
window_right,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
denoised_data[pointer : pointer + segment_size] = (
|
denoised_data[pointer:pointer+segment_size]=denoised_data[pointer:pointer+segment_size]+pred_time
|
||||||
denoised_data[pointer : pointer + segment_size] + pred_time
|
residual_noise[pointer:pointer+segment_size]=residual_noise[pointer:pointer+segment_size]+residual_time
|
||||||
)
|
|
||||||
residual_noise[pointer : pointer + segment_size] = (
|
|
||||||
residual_noise[pointer : pointer + segment_size] + residual_time
|
|
||||||
)
|
|
||||||
|
|
||||||
pointer=pointer+segment_size-overlapsize
|
pointer=pointer+segment_size-overlapsize
|
||||||
else:
|
else:
|
||||||
segment=data[pointer::]
|
segment=data[pointer::]
|
||||||
lensegment=len(segment)
|
lensegment=len(segment)
|
||||||
segment = np.concatenate(
|
segment=np.concatenate((segment, np.zeros(shape=(int(segment_size-len(segment)),))), axis=0)
|
||||||
(segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0
|
audio_finished=True
|
||||||
)
|
|
||||||
#dostft
|
#dostft
|
||||||
segment_TF=do_stft(segment)
|
segment_TF=do_stft(segment)
|
||||||
|
|
||||||
@ -182,27 +121,11 @@ def run(args):
|
|||||||
pred_time=pred_time
|
pred_time=pred_time
|
||||||
residual_time=residual_time
|
residual_time=residual_time
|
||||||
else:
|
else:
|
||||||
pred_time = np.concatenate(
|
pred_time=np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||||
(
|
residual_time=np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]),axis=0)
|
||||||
np.multiply(pred_time[0 : int(overlapsize)], window_left),
|
|
||||||
pred_time[int(overlapsize) : int(segment_size)],
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
residual_time = np.concatenate(
|
|
||||||
(
|
|
||||||
np.multiply(residual_time[0 : int(overlapsize)], window_left),
|
|
||||||
residual_time[int(overlapsize) : int(segment_size)],
|
|
||||||
),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
denoised_data[pointer::] = (
|
denoised_data[pointer::]=denoised_data[pointer::]+pred_time[0:lensegment]
|
||||||
denoised_data[pointer::] + pred_time[0:lensegment]
|
residual_noise[pointer::]=residual_noise[pointer::]+residual_time[0:lensegment]
|
||||||
)
|
|
||||||
residual_noise[pointer::] = (
|
|
||||||
residual_noise[pointer::] + residual_time[0:lensegment]
|
|
||||||
)
|
|
||||||
|
|
||||||
basename=os.path.splitext(audio)[0]
|
basename=os.path.splitext(audio)[0]
|
||||||
wav_noisy_name=basename+"_noisy_input"+".wav"
|
wav_noisy_name=basename+"_noisy_input"+".wav"
|
||||||
@ -233,3 +156,10 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
|
||||||
python inference.py inference.audio="$1"
|
python inference.py inference.audio=$1
|
||||||
|
|
||||||
|
|||||||
109
train.py
109
train.py
@ -4,14 +4,15 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
import unet
|
import unet
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import dataset_loader
|
import dataset_loader
|
||||||
from tensorflow.keras.optimizers import Adam
|
from tensorflow.keras.optimizers import Adam
|
||||||
|
import soundfile as sf
|
||||||
import datetime
|
import datetime
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
path_experiment=str(args.path_experiment)
|
path_experiment=str(args.path_experiment)
|
||||||
|
|
||||||
@ -19,70 +20,55 @@ def run(args):
|
|||||||
os.makedirs(path_experiment)
|
os.makedirs(path_experiment)
|
||||||
|
|
||||||
path_music_train=args.dset.path_music_train
|
path_music_train=args.dset.path_music_train
|
||||||
|
path_music_test=args.dset.path_music_test
|
||||||
path_music_validation=args.dset.path_music_validation
|
path_music_validation=args.dset.path_music_validation
|
||||||
|
|
||||||
path_noise=args.dset.path_noise
|
path_noise=args.dset.path_noise
|
||||||
|
path_recordings=args.dset.path_recordings
|
||||||
|
|
||||||
fs=args.fs
|
fs=args.fs
|
||||||
|
overlap=args.overlap
|
||||||
seg_len_s_train=args.seg_len_s_train
|
seg_len_s_train=args.seg_len_s_train
|
||||||
|
|
||||||
batch_size=args.batch_size
|
batch_size=args.batch_size
|
||||||
epochs=args.epochs
|
epochs=args.epochs
|
||||||
|
|
||||||
|
num_real_test_segments=args.num_real_test_segments
|
||||||
buffer_size=args.buffer_size #for shuffle
|
buffer_size=args.buffer_size #for shuffle
|
||||||
|
|
||||||
tensorboard_logs=args.tensorboard_logs
|
tensorboard_logs=args.tensorboard_logs
|
||||||
|
|
||||||
def do_stft(noisy, clean=None):
|
def do_stft(noisy, clean=None):
|
||||||
|
|
||||||
window_fn = tf.signal.hamming_window
|
window_fn = tf.signal.hamming_window
|
||||||
|
|
||||||
win_size=args.stft.win_size
|
win_size=args.stft.win_size
|
||||||
hop_size=args.stft.hop_size
|
hop_size=args.stft.hop_size
|
||||||
|
|
||||||
stft_signal_noisy = tf.signal.stft(
|
|
||||||
noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
|
||||||
)
|
|
||||||
stft_noisy_stacked = tf.stack(
|
|
||||||
values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if clean is not None:
|
stft_signal_noisy=tf.signal.stft(noisy,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||||
stft_signal_clean = tf.signal.stft(
|
stft_noisy_stacked=tf.stack( values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
|
||||||
clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
|
|
||||||
)
|
if clean!=None:
|
||||||
stft_clean_stacked = tf.stack(
|
|
||||||
values=[
|
stft_signal_clean=tf.signal.stft(clean,frame_length=win_size, window_fn=window_fn, frame_step=hop_size)
|
||||||
tf.math.real(stft_signal_clean),
|
stft_clean_stacked=tf.stack( values=[tf.math.real(stft_signal_clean), tf.math.imag(stft_signal_clean)], axis=-1)
|
||||||
tf.math.imag(stft_signal_clean),
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
return stft_noisy_stacked, stft_clean_stacked
|
return stft_noisy_stacked, stft_clean_stacked
|
||||||
else:
|
else:
|
||||||
|
|
||||||
return stft_noisy_stacked
|
return stft_noisy_stacked
|
||||||
|
|
||||||
#Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
#Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
|
||||||
|
|
||||||
dataset_train, dataset_val = dataset_loader.load_data(
|
dataset_train, dataset_val=dataset_loader.load_data(buffer_size, path_music_train, path_music_validation, path_noise, fs=fs, seg_len_s=seg_len_s_train)
|
||||||
buffer_size,
|
|
||||||
path_music_train,
|
|
||||||
path_music_validation,
|
|
||||||
path_noise,
|
|
||||||
fs=fs,
|
|
||||||
seg_len_s=seg_len_s_train,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_train = dataset_train.map(
|
dataset_train=dataset_train.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
dataset_val=dataset_val.map(do_stft, num_parallel_calls=args.num_workers, deterministic=None)
|
||||||
)
|
|
||||||
dataset_val = dataset_val.map(
|
|
||||||
do_stft, num_parallel_calls=args.num_workers, deterministic=None
|
|
||||||
)
|
|
||||||
|
|
||||||
strategy = tf.distribute.MirroredStrategy()
|
strategy = tf.distribute.MirroredStrategy()
|
||||||
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
|
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
#build the model
|
#build the model
|
||||||
@ -94,17 +80,12 @@ def run(args):
|
|||||||
loss=tf.keras.losses.MeanAbsoluteError()
|
loss=tf.keras.losses.MeanAbsoluteError()
|
||||||
|
|
||||||
if args.use_tensorboard:
|
if args.use_tensorboard:
|
||||||
log_dir = os.path.join(
|
log_dir = os.path.join(tensorboard_logs, os.path.basename(path_experiment)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||||
tensorboard_logs,
|
|
||||||
os.path.basename(path_experiment)
|
|
||||||
+ "_"
|
|
||||||
+ datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
||||||
)
|
|
||||||
train_summary_writer = tf.summary.create_file_writer(log_dir+"/train")
|
train_summary_writer = tf.summary.create_file_writer(log_dir+"/train")
|
||||||
val_summary_writer = tf.summary.create_file_writer(log_dir+"/validation")
|
val_summary_writer = tf.summary.create_file_writer(log_dir+"/validation")
|
||||||
|
|
||||||
#path where the checkpoints will be saved
|
#path where the checkpoints will be saved
|
||||||
checkpoint_filepath = os.path.join(path_experiment, "checkpoint")
|
checkpoint_filepath=os.path.join(path_experiment, 'checkpoint')
|
||||||
|
|
||||||
dataset_train=dataset_train.batch(batch_size)
|
dataset_train=dataset_train.batch(batch_size)
|
||||||
dataset_val=dataset_val.batch(batch_size)
|
dataset_val=dataset_val.batch(batch_size)
|
||||||
@ -125,43 +106,27 @@ def run(args):
|
|||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
total_loss=0
|
total_loss=0
|
||||||
step_loss=0
|
step_loss=0
|
||||||
for step in tqdm(
|
for step in tqdm(range(args.steps_per_epoch), desc="Training epoch "+str(epoch)):
|
||||||
range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
|
|
||||||
):
|
|
||||||
step_loss=trainer.distributed_training_step(iterator.get_next())
|
step_loss=trainer.distributed_training_step(iterator.get_next())
|
||||||
total_loss+=step_loss
|
total_loss+=step_loss
|
||||||
with train_summary_writer.as_default():
|
with train_summary_writer.as_default():
|
||||||
tf.summary.scalar("batch_loss", step_loss, step=step)
|
tf.summary.scalar('batch_loss', step_loss, step=step)
|
||||||
tf.summary.scalar(
|
tf.summary.scalar('batch_mean_absolute_error', trainer.train_mae.result(), step=step)
|
||||||
"batch_mean_absolute_error", trainer.train_mae.result(), step=step
|
|
||||||
)
|
|
||||||
|
|
||||||
train_loss=total_loss/args.steps_per_epoch
|
train_loss=total_loss/args.steps_per_epoch
|
||||||
|
|
||||||
for x in tqdm(dataset_val, desc="Validating epoch "+str(epoch)):
|
for x in tqdm(dataset_val, desc="Validating epoch "+str(epoch)):
|
||||||
trainer.distributed_test_step(x)
|
trainer.distributed_test_step(x)
|
||||||
|
|
||||||
template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
|
template = ("Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}")
|
||||||
print(
|
print (template.format(epoch+1, train_loss, trainer.train_mae.result(), trainer.val_loss.result(), trainer.val_mae.result()))
|
||||||
template.format(
|
|
||||||
epoch + 1,
|
|
||||||
train_loss,
|
|
||||||
trainer.train_mae.result(),
|
|
||||||
trainer.val_loss.result(),
|
|
||||||
trainer.val_mae.result(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with train_summary_writer.as_default():
|
with train_summary_writer.as_default():
|
||||||
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
|
tf.summary.scalar('epoch_loss', train_loss, step=epoch)
|
||||||
tf.summary.scalar(
|
tf.summary.scalar('epoch_mean_absolute_error', trainer.train_mae.result(), step=epoch)
|
||||||
"epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
|
|
||||||
)
|
|
||||||
with val_summary_writer.as_default():
|
with val_summary_writer.as_default():
|
||||||
tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
|
tf.summary.scalar('epoch_loss', trainer.val_loss.result(), step=epoch)
|
||||||
tf.summary.scalar(
|
tf.summary.scalar('epoch_mean_absolute_error', trainer.val_mae.result(), step=epoch)
|
||||||
"epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train_mae.reset_states()
|
trainer.train_mae.reset_states()
|
||||||
trainer.val_loss.reset_states()
|
trainer.val_loss.reset_states()
|
||||||
@ -172,11 +137,10 @@ def run(args):
|
|||||||
current_lr*=1e-1
|
current_lr*=1e-1
|
||||||
trainer.optimizer.lr=current_lr
|
trainer.optimizer.lr=current_lr
|
||||||
try:
|
try:
|
||||||
unet_model.save_weights(checkpoint_filepath)
|
unet_model.save_weights(checkpoint_filpath)
|
||||||
except Exception:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _main(args):
|
def _main(args):
|
||||||
global __file__
|
global __file__
|
||||||
|
|
||||||
@ -197,3 +161,10 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
35
trainer.py
35
trainer.py
@ -1,7 +1,12 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import soundfile as sf
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
class Trainer():
|
||||||
class Trainer:
|
|
||||||
def __init__(self, model, optimizer,loss, strategy, path_experiment, args):
|
def __init__(self, model, optimizer,loss, strategy, path_experiment, args):
|
||||||
self.model=model
|
self.model=model
|
||||||
print(self.model.summary())
|
print(self.model.summary())
|
||||||
@ -18,20 +23,17 @@ class Trainer:
|
|||||||
self.train_mae_s1=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
|
self.train_mae_s1=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
|
||||||
self.train_mae=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
|
self.train_mae=tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
|
||||||
self.val_mae=tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
|
self.val_mae=tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
|
||||||
self.val_loss = tf.keras.metrics.Mean(name="test_loss")
|
self.val_loss = tf.keras.metrics.Mean(name='test_loss')
|
||||||
|
|
||||||
|
|
||||||
def train_step(self,inputs):
|
def train_step(self,inputs):
|
||||||
noisy, clean= inputs
|
noisy, clean= inputs
|
||||||
|
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
logits_2, logits_1 = self.model(
|
|
||||||
noisy, training=True
|
|
||||||
) # Logits for this minibatch
|
|
||||||
|
|
||||||
loss_value = tf.reduce_mean(
|
logits_2,logits_1 = self.model(noisy, training=True) # Logits for this minibatch
|
||||||
self.loss_object(clean, logits_2)
|
|
||||||
+ tf.reduce_mean(self.loss_object(clean, logits_1))
|
loss_value = tf.reduce_mean(self.loss_object(clean, logits_2) + tf.reduce_mean(self.loss_object(clean, logits_1)))
|
||||||
)
|
|
||||||
|
|
||||||
grads = tape.gradient(loss_value, self.model.trainable_weights)
|
grads = tape.gradient(loss_value, self.model.trainable_weights)
|
||||||
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
|
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
|
||||||
@ -40,12 +42,11 @@ class Trainer:
|
|||||||
return loss_value
|
return loss_value
|
||||||
|
|
||||||
def test_step(self,inputs):
|
def test_step(self,inputs):
|
||||||
|
|
||||||
noisy,clean = inputs
|
noisy,clean = inputs
|
||||||
|
|
||||||
predictions_s2, predictions_s1 = self.model(noisy, training=False)
|
predictions_s2, predictions_s1 = self.model(noisy, training=False)
|
||||||
t_loss = self.loss_object(clean, predictions_s2) + self.loss_object(
|
t_loss = self.loss_object(clean, predictions_s2)+self.loss_object(clean, predictions_s1)
|
||||||
clean, predictions_s1
|
|
||||||
)
|
|
||||||
|
|
||||||
self.val_mae.update_state(clean,predictions_s2)
|
self.val_mae.update_state(clean,predictions_s2)
|
||||||
self.val_loss.update_state(t_loss)
|
self.val_loss.update_state(t_loss)
|
||||||
@ -53,11 +54,13 @@ class Trainer:
|
|||||||
@tf.function()
|
@tf.function()
|
||||||
def distributed_training_step(self,inputs):
|
def distributed_training_step(self,inputs):
|
||||||
per_replica_losses=self.strategy.run(self.train_step, args=(inputs,))
|
per_replica_losses=self.strategy.run(self.train_step, args=(inputs,))
|
||||||
reduced_losses = self.strategy.reduce(
|
reduced_losses=self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
|
||||||
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None
|
|
||||||
)
|
|
||||||
return reduced_losses
|
return reduced_losses
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def distributed_test_step(self,inputs):
|
def distributed_test_step(self,inputs):
|
||||||
return self.strategy.run(self.test_step, args=(inputs,))
|
return self.strategy.run(self.test_step, args=(inputs,))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
266
unet.py
266
unet.py
@ -1,11 +1,11 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras import Input
|
from tensorflow.keras import Model, Input
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
from tensorflow.keras.initializers import TruncatedNormal
|
from tensorflow.keras.initializers import TruncatedNormal
|
||||||
import math as m
|
import math as m
|
||||||
|
|
||||||
|
|
||||||
def build_model_denoise(unet_args=None):
|
def build_model_denoise(unet_args=None):
|
||||||
|
|
||||||
inputs=Input(shape=(None, None,2))
|
inputs=Input(shape=(None, None,2))
|
||||||
|
|
||||||
outputs_stage_2,outputs_stage_1=MultiStage_denoise(unet_args=unet_args)(inputs)
|
outputs_stage_2,outputs_stage_1=MultiStage_denoise(unet_args=unet_args)(inputs)
|
||||||
@ -14,20 +14,17 @@ def build_model_denoise(unet_args=None):
|
|||||||
model= tf.keras.Model(inputs=inputs,outputs=[outputs_stage_2, outputs_stage_1])
|
model= tf.keras.Model(inputs=inputs,outputs=[outputs_stage_2, outputs_stage_1])
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class DenseBlock(layers.Layer):
|
class DenseBlock(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, N] => [B, T, F, N]
|
[B, T, F, N] => [B, T, F, N]
|
||||||
DenseNet Block consisting of "num_layers" densely connected convolutional layers
|
DenseNet Block consisting of "num_layers" densely connected convolutional layers
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self, num_layers, N, ksize,activation):
|
def __init__(self, num_layers, N, ksize,activation):
|
||||||
"""
|
'''
|
||||||
num_layers: number of densely connected conv. layers
|
num_layers: number of densely connected conv. layers
|
||||||
N: Number of filters (same in each layer)
|
N: Number of filters (same in each layer)
|
||||||
ksize: Kernel size (same in each layer)
|
ksize: Kernel size (same in each layer)
|
||||||
"""
|
'''
|
||||||
super(DenseBlock, self).__init__()
|
super(DenseBlock, self).__init__()
|
||||||
self.activation=activation
|
self.activation=activation
|
||||||
|
|
||||||
@ -36,58 +33,52 @@ class DenseBlock(layers.Layer):
|
|||||||
self.num_layers=num_layers
|
self.num_layers=num_layers
|
||||||
|
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
self.H.append(
|
self.H.append(layers.Conv2D(filters=N,
|
||||||
layers.Conv2D(
|
|
||||||
filters=N,
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=self.activation,
|
activation=self.activation))
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
|
||||||
|
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||||
x_ = self.H[0](x_)
|
x_ = self.H[0](x_)
|
||||||
if self.num_layers>1:
|
if self.num_layers>1:
|
||||||
for h in self.H[1:]:
|
for h in self.H[1:]:
|
||||||
x = tf.concat([x_, x], axis=-1)
|
x = tf.concat([x_, x], axis=-1)
|
||||||
x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
x_=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||||
x_ = h(x_)
|
x_ = h(x_)
|
||||||
|
|
||||||
return x_
|
return x_
|
||||||
|
|
||||||
|
|
||||||
class FinalBlock(layers.Layer):
|
class FinalBlock(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, N] => [B, T, F, 2]
|
[B, T, F, N] => [B, T, F, 2]
|
||||||
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
|
Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
|
||||||
|
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(FinalBlock, self).__init__()
|
super(FinalBlock, self).__init__()
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.paddings_2=get_paddings(ksize)
|
self.paddings_2=get_paddings(ksize)
|
||||||
self.conv2 = layers.Conv2D(
|
self.conv2=layers.Conv2D(filters=2,
|
||||||
filters=2,
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=None,
|
activation=None)
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, inputs ):
|
def call(self, inputs ):
|
||||||
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
|
|
||||||
|
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||||
pred=self.conv2(x)
|
pred=self.conv2(x)
|
||||||
|
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
|
|
||||||
class SAM(layers.Layer):
|
class SAM(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
|
[B, T, F, N] => [B, T, F, N] , [B, T, F, N]
|
||||||
Supervised Attention Module:
|
Supervised Attention Module:
|
||||||
The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
|
The purpose of SAM is to make the network only propagate the most relevant features to the second stage, discarding the less useful ones.
|
||||||
@ -95,55 +86,48 @@ class SAM(layers.Layer):
|
|||||||
The first stage output is then calculated adding the original input spectrogram to the residual noise.
|
The first stage output is then calculated adding the original input spectrogram to the residual noise.
|
||||||
The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.
|
The attention-guided features are computed using the attention masks M, which are directly calculated from the first stage output with a 1x1 convolution and a sigmoid function.
|
||||||
|
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self, n_feat):
|
def __init__(self, n_feat):
|
||||||
super(SAM, self).__init__()
|
super(SAM, self).__init__()
|
||||||
|
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.paddings_1=get_paddings(ksize)
|
self.paddings_1=get_paddings(ksize)
|
||||||
self.conv1 = layers.Conv2D(
|
self.conv1 = layers.Conv2D(filters=n_feat,
|
||||||
filters=n_feat,
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=None,
|
activation=None)
|
||||||
)
|
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.paddings_2=get_paddings(ksize)
|
self.paddings_2=get_paddings(ksize)
|
||||||
self.conv2 = layers.Conv2D(
|
self.conv2=layers.Conv2D(filters=2,
|
||||||
filters=2,
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=None,
|
activation=None)
|
||||||
)
|
|
||||||
|
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.paddings_3=get_paddings(ksize)
|
self.paddings_3=get_paddings(ksize)
|
||||||
self.conv3 = layers.Conv2D(
|
self.conv3 = layers.Conv2D(filters=n_feat,
|
||||||
filters=n_feat,
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=None,
|
activation=None)
|
||||||
)
|
|
||||||
self.cropadd=CropAddBlock()
|
self.cropadd=CropAddBlock()
|
||||||
|
|
||||||
def call(self, inputs, input_spectrogram):
|
def call(self, inputs, input_spectrogram):
|
||||||
x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
|
x1=tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
|
||||||
x1 = self.conv1(x1)
|
x1 = self.conv1(x1)
|
||||||
|
|
||||||
x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
|
x=tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
|
||||||
x=self.conv2(x)
|
x=self.conv2(x)
|
||||||
|
|
||||||
#residual prediction
|
#residual prediction
|
||||||
pred = layers.Add()([x, input_spectrogram]) #features to next stage
|
pred = layers.Add()([x, input_spectrogram]) #features to next stage
|
||||||
|
|
||||||
x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
|
x3=tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
|
||||||
M=self.conv3(x3)
|
M=self.conv3(x3)
|
||||||
|
|
||||||
M= tf.keras.activations.sigmoid(M)
|
M= tf.keras.activations.sigmoid(M)
|
||||||
@ -154,18 +138,17 @@ class SAM(layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
class AddFreqEncoding(layers.Layer):
|
class AddFreqEncoding(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, 2] => [B, T, F, 12]
|
[B, T, F, 2] => [B, T, F, 12]
|
||||||
Generates frequency positional embeddings and concatenates them as 10 extra channels
|
Generates frequency positional embeddings and concatenates them as 10 extra channels
|
||||||
This function is optimized for F=1025
|
This function is optimized for F=1025
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self, f_dim):
|
def __init__(self, f_dim):
|
||||||
super(AddFreqEncoding, self).__init__()
|
super(AddFreqEncoding, self).__init__()
|
||||||
pi = tf.constant(m.pi)
|
pi = tf.constant(m.pi)
|
||||||
pi = tf.cast(pi, "float32")
|
pi=tf.cast(pi,'float32')
|
||||||
self.f_dim=f_dim #f_dim is fixed
|
self.f_dim=f_dim #f_dim is fixed
|
||||||
n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")
|
n=tf.cast(tf.range(f_dim)/(f_dim-1),'float32')
|
||||||
coss=tf.math.cos(pi*n)
|
coss=tf.math.cos(pi*n)
|
||||||
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
||||||
self.fembeddings= f_channel
|
self.fembeddings= f_channel
|
||||||
@ -173,38 +156,28 @@ class AddFreqEncoding(layers.Layer):
|
|||||||
for k in range(1,10):
|
for k in range(1,10):
|
||||||
coss=tf.math.cos(2**k*pi*n)
|
coss=tf.math.cos(2**k*pi*n)
|
||||||
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
f_channel = tf.expand_dims(coss, -1) #(1025,1)
|
||||||
self.fembeddings = tf.concat(
|
self.fembeddings=tf.concat([self.fembeddings,f_channel],axis=-1) #(1025,10)
|
||||||
[self.fembeddings, f_channel], axis=-1
|
|
||||||
) # (1025,10)
|
|
||||||
|
|
||||||
def call(self, input_tensor):
|
def call(self, input_tensor):
|
||||||
|
|
||||||
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
|
batch_size_tensor = tf.shape(input_tensor)[0] # get batch size
|
||||||
time_dim = tf.shape(input_tensor)[1] # get time dimension
|
time_dim = tf.shape(input_tensor)[1] # get time dimension
|
||||||
|
|
||||||
fembeddings_2 = tf.broadcast_to(
|
fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
|
||||||
self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]
|
|
||||||
)
|
|
||||||
|
|
||||||
return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
|
return tf.concat([input_tensor,fembeddings_2],axis=-1) #(batch,427,1025,12)
|
||||||
|
|
||||||
|
|
||||||
def get_paddings(K):
|
def get_paddings(K):
|
||||||
return tf.constant(
|
return tf.constant([[0,0],[K[0]//2, K[0]//2 -(1- K[0]%2) ], [ K[1]//2, K[1]//2 -(1- K[1]%2) ],[0,0]])
|
||||||
[
|
|
||||||
[0, 0],
|
|
||||||
[K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
|
|
||||||
[K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
|
|
||||||
[0, 0],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(layers.Layer):
|
class Decoder(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, N] , skip connections => [B, T, F, N]
|
[B, T, F, N] , skip connections => [B, T, F, N]
|
||||||
Decoder side of the U-Net subnetwork.
|
Decoder side of the U-Net subnetwork.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self, Ns, Ss, unet_args):
|
def __init__(self, Ns, Ss, unet_args):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
|
|
||||||
@ -213,30 +186,21 @@ class Decoder(layers.Layer):
|
|||||||
self.activation=unet_args.activation
|
self.activation=unet_args.activation
|
||||||
self.depth=unet_args.depth
|
self.depth=unet_args.depth
|
||||||
|
|
||||||
|
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.paddings_3=get_paddings(ksize)
|
self.paddings_3=get_paddings(ksize)
|
||||||
self.conv2d_3 = layers.Conv2D(
|
self.conv2d_3=layers.Conv2D(filters=self.Ns[self.depth],
|
||||||
filters=self.Ns[self.depth],
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=self.activation,
|
activation=self.activation)
|
||||||
)
|
|
||||||
|
|
||||||
self.cropadd=CropAddBlock()
|
self.cropadd=CropAddBlock()
|
||||||
|
|
||||||
self.dblocks=[]
|
self.dblocks=[]
|
||||||
for i in range(self.depth):
|
for i in range(self.depth):
|
||||||
self.dblocks.append(
|
self.dblocks.append(D_Block(layer_idx=i,N=self.Ns[i], S=self.Ss[i], activation=self.activation,num_tfc=unet_args.num_tfc))
|
||||||
D_Block(
|
|
||||||
layer_idx=i,
|
|
||||||
N=self.Ns[i],
|
|
||||||
S=self.Ss[i],
|
|
||||||
activation=self.activation,
|
|
||||||
num_tfc=unet_args.num_tfc,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self,inputs, contracting_layers):
|
def call(self,inputs, contracting_layers):
|
||||||
x=inputs
|
x=inputs
|
||||||
@ -244,13 +208,12 @@ class Decoder(layers.Layer):
|
|||||||
x=self.dblocks[i-1](x, contracting_layers[i-1])
|
x=self.dblocks[i-1](x, contracting_layers[i-1])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Encoder(tf.keras.Model):
|
class Encoder(tf.keras.Model):
|
||||||
"""
|
|
||||||
|
'''
|
||||||
[B, T, F, N] => skip connections , [B, T, F, N_4]
|
[B, T, F, N] => skip connections , [B, T, F, N_4]
|
||||||
Encoder side of the U-Net subnetwork.
|
Encoder side of the U-Net subnetwork.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self, Ns, Ss, unet_args):
|
def __init__(self, Ns, Ss, unet_args):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
self.Ns=Ns
|
self.Ns=Ns
|
||||||
@ -262,22 +225,14 @@ class Encoder(tf.keras.Model):
|
|||||||
|
|
||||||
self.eblocks=[]
|
self.eblocks=[]
|
||||||
for i in range(self.depth):
|
for i in range(self.depth):
|
||||||
self.eblocks.append(
|
self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
|
||||||
E_Block(
|
|
||||||
layer_idx=i,
|
|
||||||
N0=self.Ns[i],
|
|
||||||
N=self.Ns[i + 1],
|
|
||||||
S=self.Ss[i],
|
|
||||||
activation=self.activation,
|
|
||||||
num_tfc=unet_args.num_tfc,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
|
self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
x=inputs
|
x=inputs
|
||||||
for i in range(self.depth):
|
for i in range(self.depth):
|
||||||
|
|
||||||
x, x_contract=self.eblocks[i](x)
|
x, x_contract=self.eblocks[i](x)
|
||||||
|
|
||||||
self.contracting_layers[i] = x_contract #if remove 0, correct this
|
self.contracting_layers[i] = x_contract #if remove 0, correct this
|
||||||
@ -285,8 +240,8 @@ class Encoder(tf.keras.Model):
|
|||||||
|
|
||||||
return x, self.contracting_layers
|
return x, self.contracting_layers
|
||||||
|
|
||||||
|
|
||||||
class MultiStage_denoise(tf.keras.Model):
|
class MultiStage_denoise(tf.keras.Model):
|
||||||
|
|
||||||
def __init__(self, unet_args=None):
|
def __init__(self, unet_args=None):
|
||||||
super(MultiStage_denoise, self).__init__()
|
super(MultiStage_denoise, self).__init__()
|
||||||
|
|
||||||
@ -304,14 +259,13 @@ class MultiStage_denoise(tf.keras.Model):
|
|||||||
#initial feature extractor
|
#initial feature extractor
|
||||||
ksize=(7,7)
|
ksize=(7,7)
|
||||||
self.paddings_1=get_paddings(ksize)
|
self.paddings_1=get_paddings(ksize)
|
||||||
self.conv2d_1 = layers.Conv2D(
|
self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
|
||||||
filters=self.Ns[0],
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=self.activation,
|
activation=self.activation)
|
||||||
)
|
|
||||||
|
|
||||||
self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
|
self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
|
||||||
self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
|
self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
|
||||||
@ -327,41 +281,39 @@ class MultiStage_denoise(tf.keras.Model):
|
|||||||
#initial feature extractor
|
#initial feature extractor
|
||||||
ksize=(7,7)
|
ksize=(7,7)
|
||||||
self.paddings_2=get_paddings(ksize)
|
self.paddings_2=get_paddings(ksize)
|
||||||
self.conv2d_2 = layers.Conv2D(
|
self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
|
||||||
filters=self.Ns[0],
|
|
||||||
kernel_size=ksize,
|
kernel_size=ksize,
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=self.activation,
|
activation=self.activation)
|
||||||
)
|
|
||||||
|
|
||||||
self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
|
self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
|
||||||
self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
|
self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
|
||||||
|
|
||||||
@tf.function()
|
@tf.function()
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
|
|
||||||
if self.use_fencoding:
|
if self.use_fencoding:
|
||||||
x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
|
x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
|
||||||
else:
|
else:
|
||||||
x_w_freq=inputs
|
x_w_freq=inputs
|
||||||
|
|
||||||
#intitial feature extractor
|
#intitial feature extractor
|
||||||
x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
|
x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
|
||||||
x=self.conv2d_1(x) #None, None, 1025, 32
|
x=self.conv2d_1(x) #None, None, 1025, 32
|
||||||
|
|
||||||
x, contracting_layers_s1= self.encoder_s1(x)
|
x, contracting_layers_s1= self.encoder_s1(x)
|
||||||
#decoder
|
#decoder
|
||||||
feats_s1 = self.decoder_s1(
|
feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
|
||||||
x, contracting_layers_s1
|
|
||||||
) # None, None, 1025, 32 features
|
|
||||||
|
|
||||||
if self.num_stages>1:
|
if self.num_stages>1:
|
||||||
#SAM module
|
#SAM module
|
||||||
Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
|
Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
|
||||||
|
|
||||||
#intitial feature extractor
|
#intitial feature extractor
|
||||||
x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
|
x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
|
||||||
x=self.conv2d_2(x)
|
x=self.conv2d_2(x)
|
||||||
|
|
||||||
if self.use_sam:
|
if self.use_sam:
|
||||||
@ -371,9 +323,7 @@ class MultiStage_denoise(tf.keras.Model):
|
|||||||
|
|
||||||
x, contracting_layers_s2= self.encoder_s2(x)
|
x, contracting_layers_s2= self.encoder_s2(x)
|
||||||
|
|
||||||
feats_s2 = self.decoder_s2(
|
feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
|
||||||
x, contracting_layers_s2
|
|
||||||
) # None, None, 1025, 32 features
|
|
||||||
|
|
||||||
#consider implementing a third stage?
|
#consider implementing a third stage?
|
||||||
|
|
||||||
@ -383,27 +333,23 @@ class MultiStage_denoise(tf.keras.Model):
|
|||||||
pred_stage_1=self.finalblock(feats_s1)
|
pred_stage_1=self.finalblock(feats_s1)
|
||||||
return pred_stage_1
|
return pred_stage_1
|
||||||
|
|
||||||
|
|
||||||
class I_Block(layers.Layer):
|
class I_Block(layers.Layer):
|
||||||
"""
|
'''
|
||||||
[B, T, F, N] => [B, T, F, N]
|
[B, T, F, N] => [B, T, F, N]
|
||||||
Intermediate block:
|
Intermediate block:
|
||||||
Basically, a densenet block with a residual connection
|
Basically, a densenet block with a residual connection
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def __init__(self,N,activation, num_tfc, **kwargs):
|
def __init__(self,N,activation, num_tfc, **kwargs):
|
||||||
super(I_Block, self).__init__(**kwargs)
|
super(I_Block, self).__init__(**kwargs)
|
||||||
|
|
||||||
ksize=(3,3)
|
ksize=(3,3)
|
||||||
self.tfc=DenseBlock(num_tfc,N,ksize, activation)
|
self.tfc=DenseBlock(num_tfc,N,ksize, activation)
|
||||||
|
|
||||||
self.conv2d_res = layers.Conv2D(
|
self.conv2d_res= layers.Conv2D(filters=N,
|
||||||
filters=N,
|
|
||||||
kernel_size=(1,1),
|
kernel_size=(1,1),
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
padding="VALID",
|
padding='VALID')
|
||||||
)
|
|
||||||
|
|
||||||
def call(self,inputs):
|
def call(self,inputs):
|
||||||
x=self.tfc(inputs)
|
x=self.tfc(inputs)
|
||||||
@ -413,6 +359,7 @@ class I_Block(layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
class E_Block(layers.Layer):
|
class E_Block(layers.Layer):
|
||||||
|
|
||||||
def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs):
|
def __init__(self, layer_idx,N0, N, S,activation, num_tfc, **kwargs):
|
||||||
super(E_Block, self).__init__(**kwargs)
|
super(E_Block, self).__init__(**kwargs)
|
||||||
self.layer_idx=layer_idx
|
self.layer_idx=layer_idx
|
||||||
@ -424,33 +371,31 @@ class E_Block(layers.Layer):
|
|||||||
|
|
||||||
ksize=(S[0]+2,S[1]+2)
|
ksize=(S[0]+2,S[1]+2)
|
||||||
self.paddings_2=get_paddings(ksize)
|
self.paddings_2=get_paddings(ksize)
|
||||||
self.conv2d_2 = layers.Conv2D(
|
self.conv2d_2 = layers.Conv2D(filters=N,
|
||||||
filters=N,
|
|
||||||
kernel_size=(S[0]+2,S[1]+2),
|
kernel_size=(S[0]+2,S[1]+2),
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=S,
|
strides=S,
|
||||||
padding="VALID",
|
padding='VALID',
|
||||||
activation=self.activation,
|
activation=self.activation)
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, inputs, training=None, **kwargs):
|
def call(self, inputs, training=None, **kwargs):
|
||||||
x=self.i_block(inputs)
|
x=self.i_block(inputs)
|
||||||
|
|
||||||
x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
|
x_down=tf.pad(x, self.paddings_2, mode='SYMMETRIC')
|
||||||
x_down = self.conv2d_2(x_down)
|
x_down = self.conv2d_2(x_down)
|
||||||
|
|
||||||
return x_down, x
|
return x_down, x
|
||||||
|
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return dict(
|
return dict(layer_idx=self.layer_idx,
|
||||||
layer_idx=self.layer_idx,
|
|
||||||
N=self.N,
|
N=self.N,
|
||||||
S=self.S,
|
S=self.S,
|
||||||
**super(E_Block, self).get_config(),
|
**super(E_Block, self).get_config()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class D_Block(layers.Layer):
|
class D_Block(layers.Layer):
|
||||||
|
|
||||||
def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs):
|
def __init__(self, layer_idx, N, S,activation, num_tfc, **kwargs):
|
||||||
super(D_Block, self).__init__(**kwargs)
|
super(D_Block, self).__init__(**kwargs)
|
||||||
self.layer_idx=layer_idx
|
self.layer_idx=layer_idx
|
||||||
@ -460,35 +405,29 @@ class D_Block(layers.Layer):
|
|||||||
ksize=(S[0]+2, S[1]+2)
|
ksize=(S[0]+2, S[1]+2)
|
||||||
self.paddings_1=get_paddings(ksize)
|
self.paddings_1=get_paddings(ksize)
|
||||||
|
|
||||||
self.tconv_1 = layers.Conv2DTranspose(
|
self.tconv_1= layers.Conv2DTranspose(filters=N,
|
||||||
filters=N,
|
|
||||||
kernel_size=(S[0]+2, S[1]+2),
|
kernel_size=(S[0]+2, S[1]+2),
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=S,
|
strides=S,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
padding="VALID",
|
padding='VALID')
|
||||||
)
|
|
||||||
|
|
||||||
self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")
|
self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
|
||||||
|
|
||||||
self.projection = layers.Conv2D(
|
self.projection = layers.Conv2D(filters=N,
|
||||||
filters=N,
|
|
||||||
kernel_size=(1,1),
|
kernel_size=(1,1),
|
||||||
kernel_initializer=TruncatedNormal(),
|
kernel_initializer=TruncatedNormal(),
|
||||||
strides=1,
|
strides=1,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
padding="VALID",
|
padding='VALID')
|
||||||
)
|
|
||||||
self.cropadd=CropAddBlock()
|
self.cropadd=CropAddBlock()
|
||||||
self.cropconcat=CropConcatBlock()
|
self.cropconcat=CropConcatBlock()
|
||||||
|
|
||||||
self.i_block=I_Block(N,activation,num_tfc)
|
self.i_block=I_Block(N,activation,num_tfc)
|
||||||
|
|
||||||
def call(
|
def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None,**kwargs):
|
||||||
self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
|
|
||||||
):
|
|
||||||
x = inputs
|
x = inputs
|
||||||
x = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
|
x=tf.pad(x, self.paddings_1, mode='SYMMETRIC')
|
||||||
x = self.tconv_1(inputs)
|
x = self.tconv_1(inputs)
|
||||||
|
|
||||||
x2= self.upsampling(inputs)
|
x2= self.upsampling(inputs)
|
||||||
@ -498,40 +437,39 @@ class D_Block(layers.Layer):
|
|||||||
|
|
||||||
x= self.cropadd(x,x2)
|
x= self.cropadd(x,x2)
|
||||||
|
|
||||||
|
|
||||||
x=self.cropconcat(x,bridge)
|
x=self.cropconcat(x,bridge)
|
||||||
|
|
||||||
x=self.i_block(x)
|
x=self.i_block(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return dict(
|
return dict(layer_idx=self.layer_idx,
|
||||||
layer_idx=self.layer_idx,
|
|
||||||
N=self.N,
|
N=self.N,
|
||||||
S=self.S,
|
S=self.S,
|
||||||
**super(D_Block, self).get_config(),
|
**super(D_Block, self).get_config()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CropAddBlock(layers.Layer):
|
class CropAddBlock(layers.Layer):
|
||||||
|
|
||||||
def call(self,down_layer, x, **kwargs):
|
def call(self,down_layer, x, **kwargs):
|
||||||
x1_shape = tf.shape(down_layer)
|
x1_shape = tf.shape(down_layer)
|
||||||
x2_shape = tf.shape(x)
|
x2_shape = tf.shape(x)
|
||||||
|
|
||||||
|
|
||||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||||
|
|
||||||
down_layer_cropped = down_layer[
|
down_layer_cropped = down_layer[:,
|
||||||
:,
|
|
||||||
height_diff: (x2_shape[1] + height_diff),
|
height_diff: (x2_shape[1] + height_diff),
|
||||||
width_diff: (x2_shape[2] + width_diff),
|
width_diff: (x2_shape[2] + width_diff),
|
||||||
:,
|
:]
|
||||||
]
|
|
||||||
|
|
||||||
x = layers.Add()([down_layer_cropped, x])
|
x = layers.Add()([down_layer_cropped, x])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CropConcatBlock(layers.Layer):
|
class CropConcatBlock(layers.Layer):
|
||||||
|
|
||||||
def call(self, down_layer, x, **kwargs):
|
def call(self, down_layer, x, **kwargs):
|
||||||
x1_shape = tf.shape(down_layer)
|
x1_shape = tf.shape(down_layer)
|
||||||
x2_shape = tf.shape(x)
|
x2_shape = tf.shape(x)
|
||||||
@ -539,12 +477,10 @@ class CropConcatBlock(layers.Layer):
|
|||||||
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
height_diff = (x1_shape[1] - x2_shape[1]) // 2
|
||||||
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
width_diff = (x1_shape[2] - x2_shape[2]) // 2
|
||||||
|
|
||||||
down_layer_cropped = down_layer[
|
down_layer_cropped = down_layer[:,
|
||||||
:,
|
|
||||||
height_diff: (x2_shape[1] + height_diff),
|
height_diff: (x2_shape[1] + height_diff),
|
||||||
width_diff: (x2_shape[2] + width_diff),
|
width_diff: (x2_shape[2] + width_diff),
|
||||||
:,
|
:]
|
||||||
]
|
|
||||||
|
|
||||||
x = tf.concat([down_layer_cropped, x], axis=-1)
|
x = tf.concat([down_layer_cropped, x], axis=-1)
|
||||||
return x
|
return x
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user