# denoising-historical-data/inference.py
#
# 236 lines
# 7.7 KiB
# Python

import os
import hydra
import logging
# Module-level logger; main() uses it to record uncaught exceptions.
logger = logging.getLogger(__name__)
def run(args):
    """Denoise the configured audio file with the pretrained U-Net.

    The input is read, mixed down to mono, resampled to 44.1 kHz when its
    sample rate differs, and processed in fixed-size segments that overlap by
    ``overlapsize`` samples. Neighbouring segments are cross-faded with the
    two halves of a Hann window and summed into the output buffers.

    Three files are written next to the input file:
    ``<name>_noisy_input.wav`` (the mono, resampled input),
    ``<name>_denoised.wav`` and ``<name>_residual.wav``.

    Args:
        args: Configuration object. The fields read here are
            ``path_experiment``, ``unet``, ``stft.win_size``,
            ``stft.hop_size`` and ``inference.audio``.
    """
    import unet
    import tensorflow as tf
    import soundfile as sf
    import numpy as np
    from tqdm import tqdm
    import scipy.signal

    target_rate = 44100
    segment_size = target_rate * 5  # 5 s segments
    overlapsize = 2048  # samples (46 ms)
    # Distance between the starts of two consecutive segments.
    hop = segment_size - overlapsize

    path_experiment = str(args.path_experiment)
    unet_model = unet.build_model_denoise(unet_args=args.unet)
    # The checkpoint is looked up relative to this script's directory.
    ckpt = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), path_experiment, "checkpoint"
    )
    unet_model.load_weights(ckpt)

    def do_stft(noisy):
        """Return the STFT of `noisy` with real and imaginary parts stacked
        along a new last axis."""
        stft_signal_noisy = tf.signal.stft(
            noisy,
            frame_length=args.stft.win_size,
            window_fn=tf.signal.hamming_window,
            frame_step=args.stft.hop_size,
            pad_end=True,
        )
        return tf.stack(
            values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)],
            axis=-1,
        )

    def do_istft(data):
        """Invert `do_stft`: rebuild the complex STFT from the stacked
        real/imaginary channels and return the time-domain signal."""
        hop_size = args.stft.hop_size
        # Synthesis window matched to the Hamming analysis window.
        inv_window_fn = tf.signal.inverse_stft_window_fn(
            hop_size, forward_window_fn=tf.signal.hamming_window
        )
        pred_cpx = data[..., 0] + 1j * data[..., 1]
        return tf.signal.inverse_stft(
            pred_cpx, args.stft.win_size, hop_size, window_fn=inv_window_fn
        )

    def denoise_segment(segment):
        """Run the model on one `segment_size`-sample segment.

        Returns the denoised signal and the residual (input minus
        prediction), both as float64 arrays cut to `segment_size` samples.
        """
        segment_TF = do_stft(segment)
        segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
        pred = unet_model.predict(segment_TF_ds.batch(1))
        # NOTE(review): presumably the first index picks one of several model
        # outputs and the second drops the batch axis -- confirm against unet.
        pred = pred[0]
        residual = np.array(segment_TF - pred[0])
        # The inverse STFT may return more than segment_size samples because
        # the forward STFT pads the end; float64 copies keep the window
        # multiplications below in double precision.
        pred_time = np.array(do_istft(pred[0]), dtype=np.float64)[:segment_size]
        residual_time = np.array(do_istft(residual), dtype=np.float64)[
            :segment_size
        ]
        return pred_time, residual_time

    audio = str(args.inference.audio)
    data, samplerate = sf.read(audio)
    print(data.dtype)
    # Stereo to mono
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    if samplerate != target_rate:
        print("Resampling")
        data = scipy.signal.resample(
            data, int((target_rate / samplerate) * len(data)) + 1
        )

    length_data = len(data)
    window = np.hanning(2 * overlapsize)
    window_left = window[:overlapsize]  # rising half: fade-in
    window_right = window[overlapsize:]  # falling half: fade-out
    denoised_data = np.zeros(shape=(length_data,))
    residual_noise = np.zeros(shape=(length_data,))

    # Segments start every `hop` samples, so the count must be derived from
    # the hop. Dividing by segment_size stops too early on inputs longer than
    # one segment and leaves the tail of the output faded out and silent.
    if length_data == 0:
        numchunks = 0
    else:
        numchunks = max(0, -(-(length_data - segment_size) // hop)) + 1

    for chunk_index in tqdm(range(numchunks)):
        start = chunk_index * hop
        segment = data[start : start + segment_size]
        valid = len(segment)
        is_last = start + segment_size >= length_data
        if valid < segment_size:
            # Zero-pad the final segment up to the size the pipeline expects.
            segment = np.concatenate(
                (segment, np.zeros(shape=(segment_size - valid,))), axis=0
            )

        pred_time, residual_time = denoise_segment(segment)

        for signal in (pred_time, residual_time):
            if chunk_index > 0:
                # Fade in over the region shared with the previous segment.
                signal[:overlapsize] *= window_left
            if not is_last:
                # Fade out over the region shared with the next segment.
                signal[segment_size - overlapsize :] *= window_right

        # Overlap-add; only the first `valid` samples map onto real input.
        denoised_data[start : start + valid] += pred_time[:valid]
        residual_noise[start : start + valid] += residual_time[:valid]

    basename = os.path.splitext(audio)[0]
    wav_noisy_name = basename + "_noisy_input" + ".wav"
    sf.write(wav_noisy_name, data, target_rate)
    wav_output_name = basename + "_denoised" + ".wav"
    sf.write(wav_output_name, denoised_data, target_rate)
    wav_output_name = basename + "_residual" + ".wav"
    sf.write(wav_output_name, residual_noise, target_rate)
def _main(args):
    """Make the module's ``__file__`` absolute, then run inference.

    ``run`` builds the checkpoint path from ``__file__``, so the module-level
    value is replaced by the result of ``hydra.utils.to_absolute_path`` before
    ``run`` is called.

    Args:
        args: Configuration object, forwarded unchanged to ``run``.
    """
    global __file__
    absolute_script_path = hydra.utils.to_absolute_path(__file__)
    __file__ = absolute_script_path
    run(args)
@hydra.main(config_path="conf/conf.yaml")
def main(args):
    """Hydra entry point: run inference and exit with status 1 on failure.

    Args:
        args: Configuration object passed in by the ``hydra.main`` decorator
            (configured with ``conf/conf.yaml``) and forwarded to ``_main``.
    """
    failure_exit_code = 1
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # Hydra intercepts the exit code (fixed in a beta release that could
        # not be made to work here), so the process is terminated directly.
        os._exit(failure_exit_code)
# Script entry point: hand control to the Hydra-decorated main().
if __name__ == "__main__":
    main()