Compare commits


10 Commits

Author                  SHA1        Message                                       Date
(not shown)             9870240572  fix formatting & linters                      2025-05-08 12:56:32 +03:00
Eloi Moliner Juanpere   114fce7c84  Merge pull request #6 from JorenSix/patch-1   2022-06-29 17:50:42 +03:00
                                    Update inference.sh
Joren Six               1fe1988ed5  Update inference.sh                           2022-06-29 09:58:53 +02:00
                                    Small change to allow spaces in file names. Bash expands
                                    the variable $1 correctly even if it is in double quotes,
                                    python receives a single argument and not (if there are
                                    spaces) multiple arguments.
Eloi Moliner Juanpere   fb7a32a1ff  Update README.md                              2022-05-05 10:47:26 +03:00
Eloi Moliner Juanpere   b7d071a54c  Update README.md                              2022-01-24 10:08:01 +02:00
Eloi Moliner Juanpere   6eb46ba2fc  Update README.md                              2022-01-24 10:07:24 +02:00
Eloi Moliner Juanpere   018f4418e6  Update README.md                              2022-01-24 10:06:15 +02:00
Eloi Moliner Juanpere   a1a92afefd  Created using Colaboratory                    2022-01-24 10:00:03 +02:00
Eloi Moliner Juanpere   214c872c51  Update README.md                              2022-01-22 12:14:07 +02:00
Eloi Moliner Juanpere   210cd0edd8  Update README.md                              2022-01-22 12:11:40 +02:00
9 changed files with 1163 additions and 914 deletions

.gitignore (new file)

@@ -0,0 +1,3 @@
+experiments
+outputs
+__pycache__

README.md

@@ -13,7 +13,7 @@ width="400px"></p>
 Listen to our [audio samples](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/)
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb]
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
 ## Requirements
 You will need at least python 3.7 and CUDA 10.1 if you want to use GPU. See `requirements.txt` for the required package versions.
@@ -24,7 +24,10 @@ To install the environment through anaconda, follow the instructions:
 conda activate historical_denoiser
 ## Denoising Recordings
-Run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
+You can denoise your recordings in the cloud using the Colab notebook. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb)
+Otherwise, run the following commands to clone the repository and install the pretrained weights of the two-stage U-Net model:
 git clone https://github.com/eloimoliner/denoising-historical-recordings.git
 cd denoising-historical-recordings
@@ -37,7 +40,13 @@ If the environment is installed correctly, you can denoise an audio file by running:
 A ".wav" file with the denoised version, as well as the residual noise and the original signal in "mono", will be generated in the same directory as the input file.
 ## Training
-TODO
+To retrain the model, follow the instructions:
+Download the [Gramophone Noise Dataset](http://research.spa.aalto.fi/publications/papers/icassp22-denoising/media/datasets/Gramophone_Record_Noise_Dataset.zip), or any other dataset containing recording noises.
+Prepare a dataset of clean music (e.g. [MusicNet](https://zenodo.org/record/5120004#.YnN-96IzbmE))
 ## Remarks
 The trained model is specialized in denoising gramophone recordings, such as the ones included in this collection: https://archive.org/details/georgeblood. It has shown to be robust to a wide range of different noises, but it may produce some artifacts if you run inference on something completely different.
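As a sanity check on the three files described in the README hunk above (and produced by inference.py further down), the denoised output plus the residual should approximately reconstruct the mono input. A minimal sketch; the input name example.wav is hypothetical:

import numpy as np
import soundfile as sf

# inference.py writes these three files next to the input file
noisy, fs = sf.read("example_noisy_input.wav")
denoised, _ = sf.read("example_denoised.wav")
residual, _ = sf.read("example_residual.wav")

# the residual is computed as input minus prediction in the STFT domain,
# so denoised + residual should match the mono input up to small
# windowing error at the chunk boundaries
n = min(len(noisy), len(denoised), len(residual))
print(np.max(np.abs(noisy[:n] - (denoised[:n] + residual[:n]))))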

colab/demo.ipynb

@@ -7,7 +7,7 @@
 "colab_type": "text"
 },
 "source": [
-"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/colab/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+"<a href=\"https://colab.research.google.com/github/eloimoliner/denoising-historical-recordings/blob/master/colab/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
 ]
 },
 {
@@ -40,7 +40,8 @@
 "* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
 "* Press ▶️ on the left of each of the cells\n",
 "* View the code: Double-click any of the cells\n",
-"* Hide the code: Double click the right side of the cell\n"
+"* Hide the code: Double click the right side of the cell\n",
+"* For some reason, this notebook does not work in Firefox, so please use another browser.\n"
 ],
 "metadata": {
 "id": "8UON6ncSApA9"
@@ -207,7 +208,7 @@
 "id": "TQBDTmO4mUBx"
 },
 "id": "TQBDTmO4mUBx",
-"execution_count": 4,
+"execution_count": null,
 "outputs": []
 },
 {
@@ -243,7 +244,7 @@
 "outputId": "2d05860c-536d-45f8-92b4-d2ba6f5a54c5"
 },
 "id": "50Kmdy6AtbhW",
-"execution_count": 5,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "display_data",
@@ -296,7 +297,7 @@
 "outputId": "173f5355-2939-41fe-c702-591aa752fc7e"
 },
 "id": "0po6zpvrylc2",
-"execution_count": 6,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -333,7 +334,7 @@
 "outputId": "54588c26-0b3c-42bf-aca2-8316ab54603f"
 },
 "id": "3tEshWBezYvf",
-"execution_count": 7,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "display_data",

dataset_loader.py

@@ -4,485 +4,561 @@ import tensorflow as tf
import random
import os
import numpy as np
import soundfile as sf
import pandas as pd
import glob
from tqdm import tqdm


# generator function. It reads the csv file with pandas and loads the largest
# audio segments from each recording. If extend=False, it will only read the
# segments with length>length_seg, trim them and yield them with no further
# processing. Otherwise, if the segment length is inferior, it will extend the
# length using concatenative synthesis.
def __noise_sample_generator(info_file, fs, length_seq, split):
    head = os.path.split(info_file)[0]
    load_data = pd.read_csv(info_file)
    # split= train, validation, test
    load_data_split = load_data.loc[load_data["split"] == split]
    load_data_split = load_data_split.reset_index(drop=True)
    while True:
        r = list(range(len(load_data_split)))
        if split != "test":
            random.shuffle(r)
        for i in r:
            segments = ast.literal_eval(load_data_split.loc[i, "segments"])
            if split == "test":
                loaded_data, Fs = sf.read(
                    os.path.join(
                        head,
                        load_data_split["recording"].loc[i],
                        load_data_split["largest_segment"].loc[i],
                    )
                )
            else:
                num = np.random.randint(0, len(segments))
                loaded_data, Fs = sf.read(
                    os.path.join(
                        head, load_data_split["recording"].loc[i], segments[num]
                    )
                )
            assert fs == Fs, "wrong sampling rate"
            yield __extend_sample_by_repeating(loaded_data, fs, length_seq)


def __extend_sample_by_repeating(data, fs, seq_len):
    rpm = 78
    target_samp = seq_len
    large_data = np.zeros(shape=(target_samp, 2))
    if len(data) >= target_samp:
        large_data = data[0:target_samp]
        return large_data
    bls = (1000 * 44100) / 1000  # hardcoded
    window = np.stack((np.hanning(bls), np.hanning(bls)), axis=1)
    window_left = window[0 : int(bls / 2), :]
    window_right = window[int(bls / 2) :, :]
    bls = int(bls / 2)
    rps = rpm / 60
    period = 1 / rps
    period_sam = int(period * fs)
    overhead = len(data) % period_sam
    if overhead > bls:
        complete_periods = (len(data) // period_sam) * period_sam
    else:
        complete_periods = (len(data) // period_sam - 1) * period_sam

    a = np.multiply(data[0:bls], window_left)
    b = np.multiply(data[complete_periods : complete_periods + bls], window_right)
    c_1 = np.concatenate((data[0:complete_periods, :], b))
    c_2 = np.concatenate((a, data[bls:complete_periods, :], b))
    c_3 = np.concatenate((a, data[bls:, :]))

    large_data[0 : complete_periods + bls, :] = c_1

    pointer = complete_periods
    not_finished = True
    while not_finished:
        if target_samp > pointer + complete_periods + bls:
            large_data[pointer : pointer + complete_periods + bls] += c_2
            pointer += complete_periods
        else:
            large_data[pointer::] += c_3[0 : (target_samp - pointer)]
            # finish
            not_finished = False
    return large_data


def generate_real_recordings_data(
    path_recordings, fs=44100, seg_len_s=15, stereo=False
):
    records_info = os.path.join(path_recordings, "audio_files.txt")
    num_lines = sum(1 for line in open(records_info))
    f = open(records_info, "r")
    # load data record files
    print("Loading record files")
    records = []
    seg_len = fs * seg_len_s
    pointer = int(fs * 5)  # starting at second 5 by default
    for i in tqdm(range(num_lines)):
        audio = f.readline()
        audio = audio[:-1]
        data, fs = sf.read(os.path.join(path_recordings, audio))
        if len(data.shape) > 1 and not (stereo):
            data = np.mean(data, axis=1)
        # elif stereo and len(data.shape)==1:
        #    data=np.stack((data, data), axis=1)
        # normalize
        data = data / np.max(np.abs(data))
        segment = data[pointer : pointer + seg_len]
        records.append(segment.astype("float32"))
    return records


def generate_paired_data_test_formal(
    path_pianos,
    path_noises,
    noise_amount="low_snr",
    num_samples=-1,
    fs=44100,
    seg_len_s=5,
    extend=True,
    stereo=False,
    prenoise=False,
):
    print(num_samples)
    segments_clean = []
    segments_noisy = []
    seg_len = fs * seg_len_s
    noises_info = os.path.join(path_noises, "info.csv")
    np.random.seed(42)
    if noise_amount == "low_snr":
        SNRs = np.random.uniform(2, 6, num_samples)
    elif noise_amount == "mid_snr":
        SNRs = np.random.uniform(6, 12, num_samples)
    scales = np.random.uniform(-4, 0, num_samples)
    # SNRs=[2,6,12] #HARDCODED!!!!
    i = 0
    print(path_pianos[0])
    print(seg_len)
    train_samples = glob.glob(os.path.join(path_pianos[0], "*.wav"))
    train_samples = sorted(train_samples)
    if prenoise:
        noise_generator = __noise_sample_generator(
            noises_info, fs, seg_len + fs, "test"
        )  # adds 1 s of silence at the beginning, longer noise
    else:
        noise_generator = __noise_sample_generator(
            noises_info, fs, seg_len, "test"
        )  # this will take care of everything
    # load data clean files
    for file in tqdm(train_samples):  # add [1:5] for testing
        data_clean, samplerate = sf.read(file)
        if samplerate != fs:
            print("!!!!WRONG SAMPLE RATE!!!")
        # Stereo to mono
        if len(data_clean.shape) > 1 and not (stereo):
            data_clean = np.mean(data_clean, axis=1)
        # elif stereo and len(data_clean.shape)==1:
        #    data_clean=np.stack((data_clean, data_clean), axis=1)
        # normalize
        data_clean = data_clean / np.max(np.abs(data_clean))
        # data_clean_loaded.append(data_clean)
        # framify data clean files
        # framify arguments: seg_len, hop_size
        hop_size = int(seg_len)  # no overlap
        num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
        print(num_frames)
        if num_frames == 0:
            data_clean = np.concatenate(
                (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
                axis=0,
            )
            num_frames = 1

        data_not_finished = True
        pointer = 0
        while data_not_finished:
            if i >= num_samples:
                break
            segment = data_clean[pointer : pointer + seg_len]
            pointer = pointer + hop_size
            if pointer + seg_len > len(data_clean):
                data_not_finished = False
            segment = segment.astype("float32")
            # SNRs=np.random.uniform(2,20)
            snr = SNRs[i]
            scale = scales[i]
            # load noise signal
            data_noise = next(noise_generator)
            data_noise = np.mean(data_noise, axis=1)
            # normalize
            data_noise = data_noise / np.max(np.abs(data_noise))
            new_noise = data_noise  # if more processing needed, add here
            # load clean data
            # configure sizes
            power_clean = np.var(segment)
            # estimate noise power
            if prenoise:
                power_noise = np.var(new_noise[fs::])
            else:
                power_noise = np.var(new_noise)
            snr = 10.0 ** (snr / 10.0)
            # sum both signals according to snr
            if prenoise:
                segment = np.concatenate(
                    (np.zeros(shape=(fs,)), segment), axis=0
                )  # add one second of silence
            summed = (
                segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
            )  # not sure if this is correct, maybe revisit later!!
            summed = summed.astype("float32")
            # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
            summed = 10.0 ** (scale / 10.0) * summed
            segment = 10.0 ** (scale / 10.0) * segment
            segments_noisy.append(summed.astype("float32"))
            segments_clean.append(segment.astype("float32"))
            i = i + 1
    return segments_noisy, segments_clean


def generate_test_data(path_music, path_noises, num_samples=-1, fs=44100, seg_len_s=5):
    segments_clean = []
    segments_noisy = []
    seg_len = fs * seg_len_s
    noises_info = os.path.join(path_noises, "info.csv")
    SNRs = [2, 6, 12]  # HARDCODED!!!!
    for path in path_music:
        print(path)
        train_samples = glob.glob(os.path.join(path, "*.wav"))
        train_samples = sorted(train_samples)
        noise_generator = __noise_sample_generator(
            noises_info, fs, seg_len, "test"
        )  # this will take care of everything
        # load data clean files
        for file in tqdm(train_samples):  # add [1:5] for testing
            data_clean, samplerate = sf.read(file)
            if samplerate != fs:
                print("!!!!WRONG SAMPLE RATE!!!")
            # Stereo to mono
            if len(data_clean.shape) > 1:
                data_clean = np.mean(data_clean, axis=1)
            # normalize
            data_clean = data_clean / np.max(np.abs(data_clean))
            # data_clean_loaded.append(data_clean)
            # framify data clean files
            # framify arguments: seg_len, hop_size
            hop_size = int(seg_len)  # no overlap
            num_frames = np.floor(len(data_clean) / hop_size - seg_len / hop_size + 1)
            if num_frames == 0:
                data_clean = np.concatenate(
                    (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
                    axis=0,
                )
                num_frames = 1

            pointer = 0
            segment = data_clean[pointer : pointer + (seg_len - 2 * fs)]
            segment = segment.astype("float32")
            segment = np.concatenate(
                (np.zeros(shape=(2 * fs,)), segment), axis=0
            )  # I hope its ok
            # segments_clean.append(segment)
            for snr in SNRs:
                # load noise signal
                data_noise = next(noise_generator)
                data_noise = np.mean(data_noise, axis=1)
                # normalize
                data_noise = data_noise / np.max(np.abs(data_noise))
                new_noise = data_noise  # if more processing needed, add here
                # load clean data
                # configure sizes
                # estimate clean signal power
                power_clean = np.var(segment)
                # estimate noise power
                power_noise = np.var(new_noise)
                snr = 10.0 ** (snr / 10.0)
                # sum both signals according to snr
                summed = (
                    segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
                )  # not sure if this is correct, maybe revisit later!!
                summed = summed.astype("float32")
                # yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)
                segments_noisy.append(summed.astype("float32"))
                segments_clean.append(segment.astype("float32"))
    return segments_noisy, segments_clean


def generate_val_data(
    path_music, path_noises, split, num_samples=-1, fs=44100, seg_len_s=5
):
    val_samples = []
    for path in path_music:
        val_samples.extend(glob.glob(os.path.join(path, "*.wav")))
    # load data clean files
    print("Loading clean files")
    data_clean_loaded = []
    for ff in tqdm(range(0, len(val_samples))):  # add [1:5] for testing
        data_clean, samplerate = sf.read(val_samples[ff])
        if samplerate != fs:
            print("!!!!WRONG SAMPLE RATE!!!")
        # Stereo to mono
        if len(data_clean.shape) > 1:
            data_clean = np.mean(data_clean, axis=1)
        # normalize
        data_clean = data_clean / np.max(np.abs(data_clean))
        data_clean_loaded.append(data_clean)
        del data_clean
    # framify data clean files
    print("Framifying clean files")
    seg_len = fs * seg_len_s
    segments_clean = []
    for file in tqdm(data_clean_loaded):
        # framify arguments: seg_len, hop_size
        hop_size = int(seg_len)  # no overlap
        num_frames = np.floor(len(file) / hop_size - seg_len / hop_size + 1)
        pointer = 0
        for i in range(0, int(num_frames)):
            segment = file[pointer : pointer + seg_len]
            pointer = pointer + hop_size
            segment = segment.astype("float32")
            segments_clean.append(segment)
    del data_clean_loaded
    SNRs = np.random.uniform(2, 20, len(segments_clean))
    scales = np.random.uniform(-6, 4, len(segments_clean))
    # noise_shapes=np.random.randint(0,len(noise_samples), len(segments_clean))
    noises_info = os.path.join(path_noises, "info.csv")
    noise_generator = __noise_sample_generator(
        noises_info, fs, seg_len, split
    )  # this will take care of everything
    # generate noisy segments
    # load noise samples using pandas dataframe. Each split (train, val, test) should have its unique csv info file
    # noise_samples=glob.glob(os.path.join(path_noises,"*.wav"))
    segments_noisy = []
    print("Processing noisy segments")
    for i in tqdm(range(0, len(segments_clean))):
        # load noise signal
        data_noise = next(noise_generator)
        # Stereo to mono
        data_noise = np.mean(data_noise, axis=1)
        # normalize
        data_noise = data_noise / np.max(np.abs(data_noise))
        new_noise = data_noise  # if more processing needed, add here
        # load clean data
        data_clean = segments_clean[i]
        # configure sizes
        # estimate clean signal power
        power_clean = np.var(data_clean)
        # estimate noise power
        power_noise = np.var(new_noise)

        snr = 10.0 ** (SNRs[i] / 10.0)

        # sum both signals according to snr
        summed = (
            data_clean + np.sqrt(power_clean / (snr * power_noise)) * new_noise
        )  # not sure if this is correct, maybe revisit later!!
        # the rest is normal
        summed = 10.0 ** (scales[i] / 10.0) * summed
        segments_clean[i] = 10.0 ** (scales[i] / 10.0) * segments_clean[i]
        segments_noisy.append(summed.astype("float32"))
    return segments_noisy, segments_clean


def generator_train(
    path_music, path_noises, split, fs=44100, seg_len_s=5, extend=True, stereo=False
):
    train_samples = []
    for path in path_music:
        train_samples.extend(glob.glob(os.path.join(path.decode("utf-8"), "*.wav")))
    seg_len = fs * seg_len_s
    noises_info = os.path.join(path_noises.decode("utf-8"), "info.csv")
    noise_generator = __noise_sample_generator(
        noises_info, fs, seg_len, split.decode("utf-8")
    )  # this will take care of everything
    # load data clean files
    while True:
        random.shuffle(train_samples)
        for file in train_samples:
            data, samplerate = sf.read(file)
            assert samplerate == fs, "wrong sampling rate"
            data_clean = data
            # Stereo to mono
            if len(data.shape) > 1:
                data_clean = np.mean(data_clean, axis=1)
            # normalize
            data_clean = data_clean / np.max(np.abs(data_clean))
            # framify data clean files
            # framify arguments: seg_len, hop_size
            hop_size = int(seg_len)
            num_frames = np.floor(len(data_clean) / seg_len)
            if num_frames == 0:
                data_clean = np.concatenate(
                    (data_clean, np.zeros(shape=(int(2 * seg_len - len(data_clean)),))),
                    axis=0,
                )
                num_frames = 1
                pointer = 0
                data_clean = np.roll(
                    data_clean, np.random.randint(0, seg_len)
                )  # if only one frame, roll it for augmentation
            elif num_frames > 1:
                pointer = np.random.randint(
                    0, hop_size
                )  # initial shifting, great for augmentation, better than overlap as we get different frames at each "while" iteration
            else:
                pointer = 0
            data_not_finished = True
            while data_not_finished:
                segment = data_clean[pointer : pointer + seg_len]
                pointer = pointer + hop_size
                if pointer + seg_len > len(data_clean):
                    data_not_finished = False
                segment = segment.astype("float32")
                SNRs = np.random.uniform(2, 20)
                scale = np.random.uniform(-6, 4)
                # load noise signal
                data_noise = next(noise_generator)
                data_noise = np.mean(data_noise, axis=1)
                # normalize
                data_noise = data_noise / np.max(np.abs(data_noise))
                new_noise = data_noise  # if more processing needed, add here
                # load clean data
                # configure sizes
                if stereo:
                    # estimate clean signal power
                    power_clean = 0.5 * np.var(segment[:, 0]) + 0.5 * np.var(
                        segment[:, 1]
                    )
                    # estimate noise power
                    power_noise = 0.5 * np.var(new_noise[:, 0]) + 0.5 * np.var(
                        new_noise[:, 1]
                    )
                else:
                    # estimate clean signal power
                    power_clean = np.var(segment)
                    # estimate noise power
                    power_noise = np.var(new_noise)
                snr = 10.0 ** (SNRs / 10.0)
                # sum both signals according to snr
                summed = (
                    segment + np.sqrt(power_clean / (snr * power_noise)) * new_noise
                )  # not sure if this is correct, maybe revisit later!!
                summed = 10.0 ** (scale / 10.0) * summed
                segment = 10.0 ** (scale / 10.0) * segment
                summed = summed.astype("float32")
                yield tf.convert_to_tensor(summed), tf.convert_to_tensor(segment)


def load_data(
    buffer_size,
    path_music_train,
    path_music_val,
    path_noises,
    fs=44100,
    seg_len_s=5,
    extend=True,
    stereo=False,
):
    print("Generating train dataset")
    trainshape = int(fs * seg_len_s)
    dataset_train = tf.data.Dataset.from_generator(
        generator_train,
        args=(path_music_train, path_noises, "train", fs, seg_len_s, extend, stereo),
        output_shapes=(tf.TensorShape((trainshape,)), tf.TensorShape((trainshape,))),
        output_types=(tf.float32, tf.float32),
    )
    print("Generating validation dataset")
    segments_noisy, segments_clean = generate_val_data(
        path_music_val, path_noises, "validation", fs=fs, seg_len_s=seg_len_s
    )
    dataset_val = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
    return dataset_train.shuffle(buffer_size), dataset_val


def load_data_test(buffer_size, path_pianos_test, path_noises, **kwargs):
    print("Generating test dataset")
    segments_noisy, segments_clean = generate_test_data(
        path_pianos_test, path_noises, **kwargs
    )
    dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
    # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
    return dataset_test


def load_data_formal(path_pianos_test, path_noises, **kwargs):
    print("Generating test dataset")
    segments_noisy, segments_clean = generate_paired_data_test_formal(
        path_pianos_test, path_noises, extend=True, **kwargs
    )
    print("segments::")
    print(len(segments_noisy))
    dataset_test = tf.data.Dataset.from_tensor_slices((segments_noisy, segments_clean))
    # dataset_test=tf.data.Dataset.from_tensor_slices((segments_noisy[1:3], segments_clean[1:3]))
    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
    return dataset_test


def load_real_test_recordings(buffer_size, path_recordings, **kwargs):
    print("Generating real test dataset")
    segments_noisy = generate_real_recordings_data(path_recordings, **kwargs)
    dataset_test = tf.data.Dataset.from_tensor_slices(segments_noisy)
    # train_dataset = train.cache().shuffle(buffer_size).take(info.splits["train"].num_examples)
    return dataset_test
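Every generator above mixes clean music with noise through the same step: draw an SNR in dB, scale the noise so the clean-to-noise power ratio matches it, and add. A self-contained sketch of that step, with synthetic signals standing in for the real datasets:

import numpy as np

def mix_at_snr(clean, noise, snr_db):
    # scale the noise so the clean/noise power ratio matches snr_db,
    # mirroring "summed = segment + sqrt(power_clean/(snr*power_noise))*noise"
    power_clean = np.var(clean)
    power_noise = np.var(noise)
    snr_lin = 10.0 ** (snr_db / 10.0)  # dB to linear power ratio
    return clean + np.sqrt(power_clean / (snr_lin * power_noise)) * noise

t = np.arange(44100) / 44100.0
clean = np.sin(2 * np.pi * 440.0 * t)  # 1 s of a 440 Hz tone
noise = np.random.randn(44100)
noisy = mix_at_snr(clean, noise, snr_db=6.0)

# the realized SNR should come out close to 6 dB
print(10 * np.log10(np.var(clean) / np.var(noisy - clean)))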

inference.py

@@ -4,6 +4,7 @@ import logging
logger = logging.getLogger(__name__)


def run(args):
    import unet
    import tensorflow as tf
@@ -12,127 +13,203 @@ def run(args):
    from tqdm import tqdm
    import scipy.signal

    path_experiment = str(args.path_experiment)

    unet_model = unet.build_model_denoise(unet_args=args.unet)

    ckpt = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
    )
    unet_model.load_weights(ckpt)

    def do_stft(noisy):
        window_fn = tf.signal.hamming_window
        win_size = args.stft.win_size
        hop_size = args.stft.hop_size

        stft_signal_noisy = tf.signal.stft(
            noisy,
            frame_length=win_size,
            window_fn=window_fn,
            frame_step=hop_size,
            pad_end=True,
        )
        stft_noisy_stacked = tf.stack(
            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
            axis=-1,
        )
        return stft_noisy_stacked

    def do_istft(data):
        window_fn = tf.signal.hamming_window
        win_size = args.stft.win_size
        hop_size = args.stft.hop_size
        inv_window_fn = tf.signal.inverse_stft_window_fn(
            hop_size, forward_window_fn=window_fn
        )
        pred_cpx = data[..., 0] + 1j * data[..., 1]
        pred_time = tf.signal.inverse_stft(
            pred_cpx, win_size, hop_size, window_fn=inv_window_fn
        )
        return pred_time

    audio = str(args.inference.audio)
    data, samplerate = sf.read(audio)
    print(data.dtype)
    # Stereo to mono
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)

    if samplerate != 44100:
        print("Resampling")
        data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1)

    segment_size = 44100 * 5  # 5 s segments
    length_data = len(data)
    overlapsize = 2048  # samples (46 ms)
    window = np.hanning(2 * overlapsize)
    window_right = window[overlapsize::]
    window_left = window[0:overlapsize]
    pointer = 0
    denoised_data = np.zeros(shape=(len(data),))
    residual_noise = np.zeros(shape=(len(data),))
    numchunks = int(np.ceil(length_data / segment_size))

    for i in tqdm(range(numchunks)):
        if pointer + segment_size < length_data:
            segment = data[pointer : pointer + segment_size]
            # do stft
            segment_TF = do_stft(segment)
            segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
            pred = unet_model.predict(segment_TF_ds.batch(1))
            pred = pred[0]
            residual = segment_TF - pred[0]
            residual = np.array(residual)
            pred_time = do_istft(pred[0])
            residual_time = do_istft(residual)
            residual_time = np.array(residual_time)
            if pointer == 0:
                pred_time = np.concatenate(
                    (
                        pred_time[0 : int(segment_size - overlapsize)],
                        np.multiply(
                            pred_time[int(segment_size - overlapsize) : segment_size],
                            window_right,
                        ),
                    ),
                    axis=0,
                )
                residual_time = np.concatenate(
                    (
                        residual_time[0 : int(segment_size - overlapsize)],
                        np.multiply(
                            residual_time[
                                int(segment_size - overlapsize) : segment_size
                            ],
                            window_right,
                        ),
                    ),
                    axis=0,
                )
            else:
                pred_time = np.concatenate(
                    (
                        np.multiply(pred_time[0 : int(overlapsize)], window_left),
                        pred_time[int(overlapsize) : int(segment_size - overlapsize)],
                        np.multiply(
                            pred_time[
                                int(segment_size - overlapsize) : int(segment_size)
                            ],
                            window_right,
                        ),
                    ),
                    axis=0,
                )
                residual_time = np.concatenate(
                    (
                        np.multiply(residual_time[0 : int(overlapsize)], window_left),
                        residual_time[
                            int(overlapsize) : int(segment_size - overlapsize)
                        ],
                        np.multiply(
                            residual_time[
                                int(segment_size - overlapsize) : int(segment_size)
                            ],
                            window_right,
                        ),
                    ),
                    axis=0,
                )

            denoised_data[pointer : pointer + segment_size] = (
                denoised_data[pointer : pointer + segment_size] + pred_time
            )
            residual_noise[pointer : pointer + segment_size] = (
                residual_noise[pointer : pointer + segment_size] + residual_time
            )

            pointer = pointer + segment_size - overlapsize
        else:
            segment = data[pointer::]
            lensegment = len(segment)
            segment = np.concatenate(
                (segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0
            )
            # do stft
            segment_TF = do_stft(segment)
            segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
            pred = unet_model.predict(segment_TF_ds.batch(1))
            pred = pred[0]
            residual = segment_TF - pred[0]
            residual = np.array(residual)
            pred_time = do_istft(pred[0])
            pred_time = np.array(pred_time)
            pred_time = pred_time[0:segment_size]
            residual_time = do_istft(residual)
            residual_time = np.array(residual_time)
            residual_time = residual_time[0:segment_size]
            if pointer == 0:
                pred_time = pred_time
                residual_time = residual_time
            else:
                pred_time = np.concatenate(
                    (
                        np.multiply(pred_time[0 : int(overlapsize)], window_left),
                        pred_time[int(overlapsize) : int(segment_size)],
                    ),
                    axis=0,
                )
                residual_time = np.concatenate(
                    (
                        np.multiply(residual_time[0 : int(overlapsize)], window_left),
                        residual_time[int(overlapsize) : int(segment_size)],
                    ),
                    axis=0,
                )
            denoised_data[pointer::] = (
                denoised_data[pointer::] + pred_time[0:lensegment]
            )
            residual_noise[pointer::] = (
                residual_noise[pointer::] + residual_time[0:lensegment]
            )

    basename = os.path.splitext(audio)[0]
    wav_noisy_name = basename + "_noisy_input" + ".wav"
    sf.write(wav_noisy_name, data, 44100)
    wav_output_name = basename + "_denoised" + ".wav"
    sf.write(wav_output_name, denoised_data, 44100)
    wav_output_name = basename + "_residual" + ".wav"
    sf.write(wav_output_name, residual_noise, 44100)
@@ -156,10 +233,3 @@ def main(args):
if __name__ == "__main__":
    main()
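The chunked inference above relies on the two halves of a Hann window being complementary: each 5 s chunk fades out over 2048 samples while the next one fades in, so the overlap-add preserves unity gain. A quick numerical check of that property:

import numpy as np

overlapsize = 2048  # as in inference.py (about 46 ms at 44.1 kHz)
window = np.hanning(2 * overlapsize)
window_left = window[0:overlapsize]   # fade-in for the next chunk
window_right = window[overlapsize:]   # fade-out for the current chunk

# in the overlapping region, the fade-out of one chunk plus the fade-in of
# the next sums to almost exactly one, so a constant signal passes through
# the crossfade essentially unchanged
print(np.max(np.abs(window_left + window_right - 1.0)))  # on the order of 1e-4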

inference.sh

@@ -1,5 +1,5 @@
 #!/bin/bash
-python inference.py inference.audio=$1
+python inference.py inference.audio="$1"
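The inference.sh fix above exists because of shell word splitting: without quotes, a file name containing spaces reaches Python as several arguments. A hypothetical one-line script (argv_check.py, name assumed) makes the difference visible:

import sys

# invoked as: python argv_check.py inference.audio=$1    with $1 = "my old record.wav"
#   -> sys.argv[1:] == ['inference.audio=my', 'old', 'record.wav']   (three arguments)
# invoked as: python argv_check.py inference.audio="$1"  (the patched form)
#   -> sys.argv[1:] == ['inference.audio=my old record.wav']         (one argument)
print(sys.argv[1:])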

train.py

@@ -4,143 +4,179 @@ import logging
logger = logging.getLogger(__name__)


def run(args):
    import unet
    import tensorflow as tf
    import dataset_loader
    from tensorflow.keras.optimizers import Adam
    import datetime
    from tqdm import tqdm

    path_experiment = str(args.path_experiment)

    if not os.path.exists(path_experiment):
        os.makedirs(path_experiment)

    path_music_train = args.dset.path_music_train
    path_music_validation = args.dset.path_music_validation
    path_noise = args.dset.path_noise

    fs = args.fs
    seg_len_s_train = args.seg_len_s_train
    batch_size = args.batch_size
    epochs = args.epochs
    buffer_size = args.buffer_size  # for shuffle
    tensorboard_logs = args.tensorboard_logs

    def do_stft(noisy, clean=None):
        window_fn = tf.signal.hamming_window
        win_size = args.stft.win_size
        hop_size = args.stft.hop_size

        stft_signal_noisy = tf.signal.stft(
            noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
        )
        stft_noisy_stacked = tf.stack(
            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
            axis=-1,
        )

        if clean is not None:
            stft_signal_clean = tf.signal.stft(
                clean, frame_length=win_size, window_fn=window_fn, frame_step=hop_size
            )
            stft_clean_stacked = tf.stack(
                values=[
                    tf.math.real(stft_signal_clean),
                    tf.math.imag(stft_signal_clean),
                ],
                axis=-1,
            )
            return stft_noisy_stacked, stft_clean_stacked
        else:
            return stft_noisy_stacked

    # Loading data. The train dataset object is a generator. The validation dataset is loaded in memory.
    dataset_train, dataset_val = dataset_loader.load_data(
        buffer_size,
        path_music_train,
        path_music_validation,
        path_noise,
        fs=fs,
        seg_len_s=seg_len_s_train,
    )

    dataset_train = dataset_train.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )
    dataset_val = dataset_val.map(
        do_stft, num_parallel_calls=args.num_workers, deterministic=None
    )

    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
    with strategy.scope():
        # build the model
        unet_model = unet.build_model_denoise(unet_args=args.unet)

        current_lr = args.lr
        optimizer = Adam(learning_rate=current_lr, beta_1=args.beta1, beta_2=args.beta2)
        loss = tf.keras.losses.MeanAbsoluteError()

    if args.use_tensorboard:
        log_dir = os.path.join(
            tensorboard_logs,
            os.path.basename(path_experiment)
            + "_"
            + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        train_summary_writer = tf.summary.create_file_writer(log_dir + "/train")
        val_summary_writer = tf.summary.create_file_writer(log_dir + "/validation")

    # path where the checkpoints will be saved
    checkpoint_filepath = os.path.join(path_experiment, "checkpoint")

    dataset_train = dataset_train.batch(batch_size)
    dataset_val = dataset_val.batch(batch_size)

    # prefetching the dataset for better performance
    dataset_train = dataset_train.prefetch(batch_size * 20)
    dataset_val = dataset_val.prefetch(batch_size * 20)

    dataset_train = strategy.experimental_distribute_dataset(dataset_train)
    dataset_val = strategy.experimental_distribute_dataset(dataset_val)

    iterator = iter(dataset_train)

    from trainer import Trainer

    trainer = Trainer(unet_model, optimizer, loss, strategy, path_experiment, args)

    for epoch in range(epochs):
        total_loss = 0
        step_loss = 0
        for step in tqdm(
            range(args.steps_per_epoch), desc="Training epoch " + str(epoch)
        ):
            step_loss = trainer.distributed_training_step(iterator.get_next())
            total_loss += step_loss
            with train_summary_writer.as_default():
                tf.summary.scalar("batch_loss", step_loss, step=step)
                tf.summary.scalar(
                    "batch_mean_absolute_error", trainer.train_mae.result(), step=step
                )

        train_loss = total_loss / args.steps_per_epoch

        for x in tqdm(dataset_val, desc="Validating epoch " + str(epoch)):
            trainer.distributed_test_step(x)

        template = "Epoch {}, Loss: {}, train_MAE: {}, val_Loss: {}, val_MAE: {}"
        print(
            template.format(
                epoch + 1,
                train_loss,
                trainer.train_mae.result(),
                trainer.val_loss.result(),
                trainer.val_mae.result(),
            )
        )
        with train_summary_writer.as_default():
            tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            tf.summary.scalar(
                "epoch_mean_absolute_error", trainer.train_mae.result(), step=epoch
            )
        with val_summary_writer.as_default():
            tf.summary.scalar("epoch_loss", trainer.val_loss.result(), step=epoch)
            tf.summary.scalar(
                "epoch_mean_absolute_error", trainer.val_mae.result(), step=epoch
            )

        trainer.train_mae.reset_states()
        trainer.val_loss.reset_states()
        trainer.val_mae.reset_states()

        if (epoch + 1) % 50 == 0:
            if args.variable_lr:
                current_lr *= 1e-1
                trainer.optimizer.lr = current_lr
            try:
                unet_model.save_weights(checkpoint_filepath)
            except Exception:
                pass


def _main(args):
    global __file__
@@ -161,10 +197,3 @@ def main(args):
if __name__ == "__main__":
    main()
trainer.py
View File

@@ -1,39 +1,37 @@
import tensorflow as tf


class Trainer:
    def __init__(self, model, optimizer, loss, strategy, path_experiment, args):
        self.model = model
        print(self.model.summary())
        self.strategy = strategy
        self.optimizer = optimizer
        self.path_experiment = path_experiment
        self.args = args
        # self.metrics=[]
        with self.strategy.scope():
            # loss_fn=tf.keras.losses.mean_absolute_error
            loss.reduction = tf.keras.losses.Reduction.NONE
            self.loss_object = loss
            self.train_mae_s1 = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s1")
            self.train_mae = tf.keras.metrics.MeanAbsoluteError(name="train_mae_s2")
            self.val_mae = tf.keras.metrics.MeanAbsoluteError(name="validation_mae")
            self.val_loss = tf.keras.metrics.Mean(name="test_loss")

    def train_step(self, inputs):
        noisy, clean = inputs
        with tf.GradientTape() as tape:
            logits_2, logits_1 = self.model(
                noisy, training=True
            )  # logits for this minibatch
            # two-stage loss: MAE on both the second- and first-stage outputs
            loss_value = tf.reduce_mean(
                self.loss_object(clean, logits_2)
                + tf.reduce_mean(self.loss_object(clean, logits_1))
            )
        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

@@ -41,26 +39,25 @@ class Trainer():

        self.train_mae_s1.update_state(clean, logits_1)
        return loss_value

    def test_step(self, inputs):
        noisy, clean = inputs
        predictions_s2, predictions_s1 = self.model(noisy, training=False)
        t_loss = self.loss_object(clean, predictions_s2) + self.loss_object(
            clean, predictions_s1
        )
        self.val_mae.update_state(clean, predictions_s2)
        self.val_loss.update_state(t_loss)

    @tf.function()
    def distributed_training_step(self, inputs):
        per_replica_losses = self.strategy.run(self.train_step, args=(inputs,))
        reduced_losses = self.strategy.reduce(
            tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None
        )
        return reduced_losses

    @tf.function
    def distributed_test_step(self, inputs):
        return self.strategy.run(self.test_step, args=(inputs,))
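Setting `loss.reduction = Reduction.NONE` is what makes this correct under `tf.distribute`: each replica sees only its slice of the global batch, so the loss is reduced to a scalar manually inside `train_step` and the per-replica values are then averaged with `strategy.reduce`. The same pattern in isolation (a toy function, not part of the repository):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    @tf.function
    def distributed_mean(values):
        # each replica reduces its own slice to one scalar...
        per_replica = strategy.run(tf.reduce_mean, args=(values,))
        # ...and the replica scalars are averaged into one value
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

    print(distributed_mean(tf.constant([1.0, 2.0, 3.0, 4.0])))  # 2.5 on a single replica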
646
unet.py
View File

@@ -1,84 +1,93 @@
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m


def build_model_denoise(unet_args=None):
    inputs = Input(shape=(None, None, 2))
    outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
    # encapsulating MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
    return model
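A usage sketch for `build_model_denoise`. The `unet_args` values below are hypothetical placeholders chosen to be consistent with the hard-coded `Ns`/`Ss` lists further down; the real values come from the experiment configuration:

    from types import SimpleNamespace

    import tensorflow as tf

    unet_args = SimpleNamespace(
        activation="elu",   # any Keras activation name
        depth=6,            # must match the length of the Ss list below
        num_tfc=3,          # conv. layers per DenseBlock
        use_fencoding=True,
        f_dim=1025,         # frequency bins of the STFT
        use_SAM=True,
        num_stages=2,
    )
    model = build_model_denoise(unet_args)
    noisy = tf.random.normal((1, 128, 1025, 2))  # [batch, time, freq, real/imag]
    pred_s2, pred_s1 = model(noisy)              # one prediction per stage
    print(pred_s2.shape, pred_s1.shape)          # both (1, 128, 1025, 2)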
class DenseBlock(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, N]
    DenseNet block consisting of "num_layers" densely connected convolutional layers.
    """

    def __init__(self, num_layers, N, ksize, activation):
        """
        num_layers: number of densely connected conv. layers
        N: number of filters (same in each layer)
        ksize: kernel size (same in each layer)
        """
        super(DenseBlock, self).__init__()
        self.activation = activation
        self.paddings_1 = get_paddings(ksize)
        self.H = []
        self.num_layers = num_layers
        for i in range(num_layers):
            self.H.append(
                layers.Conv2D(
                    filters=N,
                    kernel_size=ksize,
                    kernel_initializer=TruncatedNormal(),
                    strides=1,
                    padding="VALID",
                    activation=self.activation,
                )
            )

    def call(self, x):
        x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
        x_ = self.H[0](x_)
        if self.num_layers > 1:
            for h in self.H[1:]:
                x = tf.concat([x_, x], axis=-1)
                x_ = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
                x_ = h(x_)
        return x_
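Each layer in the block therefore convolves the concatenation of all previous outputs, while the block as a whole always emits `N` channels. A toy shape check (arbitrary sizes, assuming `unet.py` is importable):

    import tensorflow as tf

    from unet import DenseBlock

    block = DenseBlock(num_layers=3, N=8, ksize=(3, 3), activation="elu")
    y = block(tf.random.normal((1, 16, 16, 4)))
    # channels seen by the three convolutions: 4, then 8 + 4, then 8 + 8 + 4
    print(y.shape)  # (1, 16, 16, 8)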
class FinalBlock(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, 2]
    Final block. Basically, a 3x3 conv. layer to map the output features to the output complex spectrogram.
    """

    def __init__(self):
        super(FinalBlock, self).__init__()
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )

    def call(self, inputs):
        x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        pred = self.conv2(x)
        return pred
class SAM(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, N] , [B, T, F, N]
    Supervised Attention Module:
    The purpose of SAM is to make the network propagate only the most relevant features to the second stage, discarding the less useful ones.

@@ -86,390 +95,424 @@ class SAM(layers.Layer):

    The first-stage output is then calculated by adding the original input spectrogram to the residual noise.
    The attention-guided features are computed using the attention masks M, which are calculated directly from the first-stage output with a 3x3 convolution and a sigmoid function.
    """

    def __init__(self, n_feat):
        super(SAM, self).__init__()
        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
        x1 = self.conv1(x1)
        x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        x = self.conv2(x)  # residual prediction
        pred = layers.Add()([x, input_spectrogram])  # first-stage output
        x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
        M = self.conv3(x3)
        M = tf.keras.activations.sigmoid(M)  # attention masks
        x1 = layers.Multiply()([x1, M])
        x1 = layers.Add()([x1, inputs])  # features to next stage
        return x1, pred
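Condensed, the forward pass above computes the following, with F the incoming features and X the input spectrogram (a comment-only restatement of the code, not extra functionality):

    # pred = conv2(F) + X          # stage-1 spectrogram estimate  [B, T, F, 2]
    # M    = sigmoid(conv3(pred))  # attention masks               [B, T, F, N]
    # out  = conv1(F) * M + F      # gated features for stage 2    [B, T, F, N]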
class AddFreqEncoding(layers.Layer):
    """
    [B, T, F, 2] => [B, T, F, 12]
    Generates frequency positional embeddings and concatenates them as 10 extra channels.
    This function is optimized for F=1025.
    """

    def __init__(self, f_dim):
        super(AddFreqEncoding, self).__init__()
        pi = tf.constant(m.pi)
        pi = tf.cast(pi, "float32")
        self.f_dim = f_dim  # f_dim is fixed
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")
        coss = tf.math.cos(pi * n)
        f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
        self.fembeddings = f_channel

        for k in range(1, 10):
            coss = tf.math.cos(2**k * pi * n)
            f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
            self.fembeddings = tf.concat(
                [self.fembeddings, f_channel], axis=-1
            )  # (1025, 10)

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension
        fembeddings_2 = tf.broadcast_to(
            self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10]
        )
        return tf.concat([input_tensor, fembeddings_2], axis=-1)  # (batch, 427, 1025, 12)
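The embeddings carry no trainable weights: they are ten cosines whose frequency doubles from one channel to the next, giving each of the 1025 bins a unique signature. The same table can be reproduced standalone (a NumPy sketch, for inspection only):

    import numpy as np

    f_dim = 1025
    n = np.arange(f_dim) / (f_dim - 1)  # normalized bin index in [0, 1]
    emb = np.stack([np.cos((2**k) * np.pi * n) for k in range(10)], axis=-1)
    print(emb.shape)  # (1025, 10); broadcast to (batch, time, 1025, 10) in call()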
def get_paddings(K):
    return tf.constant(
        [
            [0, 0],
            [K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
            [K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
            [0, 0],
        ]
    )
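`get_paddings` computes the per-dimension amounts that make a stride-1 VALID convolution size-preserving; combined with `mode="SYMMETRIC"` in `tf.pad`, the spectrogram edges are mirrored rather than zero-padded. For example:

    from unet import get_paddings

    print(get_paddings((3, 3)).numpy().tolist())
    # [[0, 0], [1, 1], [1, 1], [0, 0]]  -> one bin/frame on each side
    print(get_paddings((4, 4)).numpy().tolist())
    # [[0, 0], [2, 1], [2, 1], [0, 0]]  -> even kernels pad asymmetrically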
class Decoder(layers.Layer):
    """
    [B, T, F, N] , skip connections => [B, T, F, N]
    Decoder side of the U-Net subnetwork.
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Decoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        # note: conv2d_3 and paddings_3 are built here but never used in call()
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv2d_3 = layers.Conv2D(
            filters=self.Ns[self.depth],
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=self.activation,
        )

        self.cropadd = CropAddBlock()

        self.dblocks = []
        for i in range(self.depth):
            self.dblocks.append(
                D_Block(
                    layer_idx=i,
                    N=self.Ns[i],
                    S=self.Ss[i],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc,
                )
            )

    def call(self, inputs, contracting_layers):
        x = inputs
        for i in range(self.depth, 0, -1):
            x = self.dblocks[i - 1](x, contracting_layers[i - 1])
        return x
class Encoder(tf.keras.Model):
    """
    [B, T, F, N] => skip connections , [B, T, F, N_4]
    Encoder side of the U-Net subnetwork.
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Encoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        self.contracting_layers = {}

        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(
                E_Block(
                    layer_idx=i,
                    N0=self.Ns[i],
                    N=self.Ns[i + 1],
                    S=self.Ss[i],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc,
                )
            )

        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x = inputs
        for i in range(self.depth):
            x, x_contract = self.eblocks[i](x)
            self.contracting_layers[i] = x_contract  # if remove 0, correct this
        x = self.i_block(x)
        return x, self.contracting_layers
class MultiStage_denoise(tf.keras.Model):
    def __init__(self, unet_args=None):
        super(MultiStage_denoise, self).__init__()
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages

        # Encoder
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]

        # initial feature extractor
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(
            filters=self.Ns[0],
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=self.activation,
        )

        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()
        self.finalblock = FinalBlock()

        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])
            # initial feature extractor
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(
                filters=self.Ns[0],
                kernel_size=ksize,
                kernel_initializer=TruncatedNormal(),
                strides=1,
                padding="VALID",
                activation=self.activation,
            )
            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):
        if self.use_fencoding:
            x_w_freq = self.freq_encoding(inputs)  # None, None, 1025, 12
        else:
            x_w_freq = inputs

        # initial feature extractor
        x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
        x = self.conv2d_1(x)  # None, None, 1025, 32

        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(
            x, contracting_layers_s1
        )  # None, None, 1025, 32 features

        if self.num_stages > 1:
            # SAM module
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
            # initial feature extractor
            x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
            x = self.conv2d_2(x)
            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x, feats_s1], axis=-1)

            x, contracting_layers_s2 = self.encoder_s2(x)
            feats_s2 = self.decoder_s2(
                x, contracting_layers_s2
            )  # None, None, 1025, 32 features

            # consider implementing a third stage?
            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1
class I_Block(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, N]
    Intermediate block:
    basically, a DenseNet block with a residual connection.
    """

    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)
        ksize = (3, 3)
        self.tfc = DenseBlock(num_tfc, N, ksize, activation)
        self.conv2d_res = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
        )

    def call(self, inputs):
        x = self.tfc(inputs)
        inputs_proj = self.conv2d_res(inputs)
        return layers.Add()([x, inputs_proj])
class E_Block(layers.Layer):
    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.i_block = I_Block(N0, activation, num_tfc)

        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(
            filters=N,
            kernel_size=(S[0] + 2, S[1] + 2),
            kernel_initializer=TruncatedNormal(),
            strides=S,
            padding="VALID",
            activation=self.activation,
        )

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)
        x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
        x_down = self.conv2d_2(x_down)
        return x_down, x

    def get_config(self):
        return dict(
            layer_idx=self.layer_idx,
            N=self.N,
            S=self.S,
            **super(E_Block, self).get_config(),
        )
class D_Block(layers.Layer):
    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation

        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_1 = get_paddings(ksize)
        self.tconv_1 = layers.Conv2DTranspose(
            filters=N,
            kernel_size=(S[0] + 2, S[1] + 2),
            kernel_initializer=TruncatedNormal(),
            strides=S,
            activation=self.activation,
            padding="VALID",
        )

        self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")
        self.projection = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            activation=self.activation,
            padding="VALID",
        )
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()
        self.i_block = I_Block(N, activation, num_tfc)

    def call(
        self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
    ):
        x = inputs
        # note: the padded tensor below is never used; tconv_1 acts on the raw inputs
        x = tf.pad(x, self.paddings_1, mode="SYMMETRIC")
        x = self.tconv_1(inputs)
        x2 = self.upsampling(inputs)
        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)
        x = self.cropadd(x, x2)
        x = self.cropconcat(x, bridge)
        x = self.i_block(x)
        return x

    def get_config(self):
        return dict(
            layer_idx=self.layer_idx,
            N=self.N,
            S=self.S,
            **super(D_Block, self).get_config(),
        )
class CropAddBlock(layers.Layer):
    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)
        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2
        down_layer_cropped = down_layer[
            :,
            height_diff : (x2_shape[1] + height_diff),
            width_diff : (x2_shape[2] + width_diff),
            :,
        ]
        x = layers.Add()([down_layer_cropped, x])
        return x


@@ -477,10 +520,31 @@ class CropConcatBlock(layers.Layer):

class CropConcatBlock(layers.Layer):
    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)
        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2
        down_layer_cropped = down_layer[
            :,
            height_diff : (x2_shape[1] + height_diff),
            width_diff : (x2_shape[2] + width_diff),
            :,
        ]
        x = tf.concat([down_layer_cropped, x], axis=-1)
        return x
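Both blocks center-crop the larger tensor to the smaller one's time/frequency extent before merging, which is what allows the U-Net to handle inputs whose sizes are not exact multiples of the downsampling factors. A toy check (arbitrary shapes, assuming `unet.py` is importable):

    import tensorflow as tf

    from unet import CropAddBlock, CropConcatBlock

    big = tf.random.normal((1, 36, 40, 8))
    small = tf.random.normal((1, 33, 37, 8))
    print(CropAddBlock()(big, small).shape)     # (1, 33, 37, 8)
    print(CropConcatBlock()(big, small).shape)  # (1, 33, 37, 16)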