551 lines
16 KiB
Python
551 lines
16 KiB
Python
import tensorflow as tf
|
|
from tensorflow.keras import Input
|
|
from tensorflow.keras import layers
|
|
from tensorflow.keras.initializers import TruncatedNormal
|
|
import math as m
|
|
|
|
|
|
def build_model_denoise(unet_args=None):
    """Wrap :class:`MultiStage_denoise` in a functional ``tf.keras.Model``.

    The model input has shape ``(batch, T, F, 2)``; the time and frequency
    sizes are left unspecified (``None``).

    Args:
        unet_args: Configuration object forwarded to ``MultiStage_denoise``.
            It is read there and in its sub-blocks (``activation``, ``depth``,
            ``use_fencoding``, ``f_dim``, ``use_SAM``, ``num_stages``,
            ``num_tfc``).

    Returns:
        A ``tf.keras.Model``. With ``num_stages > 1`` its outputs are
        ``[pred_stage_2, pred_stage_1]``; with a single stage the only output
        is ``pred_stage_1``.
    """
    inputs = Input(shape=(None, None, 2))

    outputs = MultiStage_denoise(unet_args=unet_args)(inputs)

    # MultiStage_denoise returns a (stage_2, stage_1) pair only when it has
    # more than one stage; with a single stage it returns one tensor, which
    # must not be unpacked.
    if isinstance(outputs, (tuple, list)):
        outputs_stage_2, outputs_stage_1 = outputs
        model_outputs = [outputs_stage_2, outputs_stage_1]
    else:
        model_outputs = outputs

    # Encapsulating MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=model_outputs)

    return model
|
|
|
|
|
|
class DenseBlock(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, N]

    DenseNet-style block of ``num_layers`` convolutional layers. Every layer
    after the first receives the channel-wise concatenation of the previous
    layer's output with all features accumulated so far (starting from the
    block input); only the last layer's output is returned.
    """

    def __init__(self, num_layers, N, ksize, activation):
        """
        num_layers: how many densely connected conv. layers are stacked
        N: number of filters of every conv. layer
        ksize: kernel size of every conv. layer
        activation: activation passed to every conv. layer
        """
        super(DenseBlock, self).__init__()

        self.activation = activation
        self.paddings_1 = get_paddings(ksize)
        self.H = []
        self.num_layers = num_layers

        for _ in range(num_layers):
            conv = layers.Conv2D(
                filters=N,
                kernel_size=ksize,
                kernel_initializer=TruncatedNormal(),
                strides=1,
                padding="VALID",
                activation=self.activation,
            )
            self.H.append(conv)

    def call(self, x):
        # Symmetric padding of (ksize - 1) in total, followed by a stride-1
        # VALID conv, keeps T and F unchanged.
        features = x
        out = self.H[0](tf.pad(features, self.paddings_1, mode="SYMMETRIC"))

        for conv in self.H[1:]:
            # Dense connectivity: the newest output is stacked on top of
            # everything accumulated so far before the next conv.
            features = tf.concat([out, features], axis=-1)
            out = conv(tf.pad(features, self.paddings_1, mode="SYMMETRIC"))

        return out
|
|
|
|
|
|
class FinalBlock(layers.Layer):
    """
    [B, T, F, N] => [B, T, F, 2]

    Output head: one 3x3 convolution without activation that maps the feature
    maps onto the two channels of the output complex spectrogram.
    """

    def __init__(self):
        super(FinalBlock, self).__init__()

        kernel = (3, 3)
        self.paddings_2 = get_paddings(kernel)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=kernel,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )

    def call(self, inputs):
        # Symmetric padding lets the VALID 3x3 conv keep T and F unchanged.
        padded = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        return self.conv2(padded)
|
|
|
|
|
|
class SAM(layers.Layer):
    """
    [B, T, F, N] , [B, T, F, 2] => [B, T, F, N] , [B, T, F, 2]

    Supervised Attention Module:
    The purpose of SAM is to make the network only propagate the most relevant
    features to the second stage, discarding the less useful ones.

    - The estimated residual noise signal is generated from the U-Net output
      features by a 3x3 convolutional layer with 2 output channels.
    - The first-stage output (prediction) is the original input spectrogram
      plus that residual.
    - Attention masks M are computed from the first-stage output with a 3x3
      convolution followed by a sigmoid.
    - The attention-guided features are the 3x3-convolved input features
      multiplied by M, plus the input features (residual connection).
    """

    def __init__(self, n_feat):
        super(SAM, self).__init__()

        # Feature branch: n_feat -> n_feat channels.
        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )

        # Residual branch: features -> 2-channel residual spectrogram.
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(
            filters=2,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )

        # Mask branch: first-stage prediction -> n_feat attention logits.
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(
            filters=n_feat,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=None,
        )

        # NOTE(review): not used in call(); kept so the layer's attributes
        # are unchanged.
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        """
        inputs: U-Net output features; assumed to have n_feat channels so the
            final residual addition matches (TODO confirm for other callers).
        input_spectrogram: the network input, 2 channels.

        Returns (attention-guided features, first-stage prediction).
        """
        x1 = tf.pad(inputs, self.paddings_1, mode="SYMMETRIC")
        x1 = self.conv1(x1)

        x = tf.pad(inputs, self.paddings_2, mode="SYMMETRIC")
        x = self.conv2(x)

        # First-stage prediction = input spectrogram + estimated residual.
        # Plain tensor ops are used instead of instantiating keras Add /
        # Multiply layers inside call().
        pred = x + input_spectrogram

        x3 = tf.pad(pred, self.paddings_3, mode="SYMMETRIC")
        M = self.conv3(x3)
        M = tf.keras.activations.sigmoid(M)

        # Gate the features with the masks and add the incoming features
        # back; this is what is propagated to the next stage.
        x1 = x1 * M
        x1 = x1 + inputs

        return x1, pred
|
|
|
|
|
|
class AddFreqEncoding(layers.Layer):
    """
    [B, T, F, C] => [B, T, F, C + num_channels]   (num_channels = 10 by default)

    Generates cosine frequency positional embeddings and concatenates them to
    the input as extra channels. Channel k (k = 0 .. num_channels - 1) holds
    cos(2**k * pi * f / (f_dim - 1)) for frequency bin f = 0 .. f_dim - 1.

    f_dim is fixed at construction time and must equal the F dimension of the
    inputs (the comments in the original model assume F=1025).
    """

    def __init__(self, f_dim, num_channels=10):
        super(AddFreqEncoding, self).__init__()
        if f_dim < 2:
            # f_dim - 1 is used as a divisor; f_dim == 1 used to yield NaNs.
            raise ValueError(
                "AddFreqEncoding requires f_dim >= 2, got {}".format(f_dim)
            )

        pi = tf.constant(m.pi)
        pi = tf.cast(pi, "float32")

        self.f_dim = f_dim  # f_dim is fixed
        self.num_channels = num_channels

        # Normalized frequency axis in [0, 1], shape (f_dim,).
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), "float32")

        # Angular multipliers pi * 2**k. Powers of two scale exactly in
        # float32, so n * (2**k * pi) equals the per-channel product the
        # original concat loop computed.
        freqs = pi * tf.constant(
            [2.0**k for k in range(num_channels)], dtype="float32"
        )

        # (f_dim, 1) * (1, num_channels) -> (f_dim, num_channels)
        self.fembeddings = tf.math.cos(n[:, None] * freqs[None, :])

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension

        # Match the input dtype (e.g. float16 under a mixed-precision
        # policy) so the concat below does not fail; no-op for float32.
        fembeddings = tf.cast(self.fembeddings, input_tensor.dtype)

        fembeddings_2 = tf.broadcast_to(
            fembeddings,
            [batch_size_tensor, time_dim, self.f_dim, self.num_channels],
        )

        return tf.concat([input_tensor, fembeddings_2], axis=-1)
|
|
|
|
|
|
def get_paddings(K):
    """
    Build ``tf.pad`` paddings for a 4-D [B, T, F, N] tensor so that a stride-1
    VALID convolution with kernel size K = (kt, kf) keeps T and F unchanged.

    Axes 1 and 2 get ``k // 2`` elements before and ``(k - 1) // 2`` after
    (``k - 1`` in total); the batch and channel axes are left unpadded.
    """
    spatial = [[k // 2, (k - 1) // 2] for k in (K[0], K[1])]
    return tf.constant([[0, 0], *spatial, [0, 0]])
|
|
|
|
|
|
class Decoder(layers.Layer):
    """
    [B, T', F', Ns[depth]] , skip connections => [B, T, F, Ns[0]]

    Decoder side of the U-Net subnetwork: the D_Blocks run from the deepest
    level (depth - 1) back up to level 0, each one consuming the skip
    connection stored at the same level index.
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Decoder, self).__init__()

        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        # NOTE(review): paddings_3, conv2d_3 and cropadd are never used in
        # call(); they are kept so the layer's attributes stay unchanged.
        kernel = (3, 3)
        self.paddings_3 = get_paddings(kernel)
        self.conv2d_3 = layers.Conv2D(
            filters=self.Ns[self.depth],
            kernel_size=kernel,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=self.activation,
        )
        self.cropadd = CropAddBlock()

        # One decoding block per U-Net level, level 0 being the shallowest.
        self.dblocks = [
            D_Block(
                layer_idx=level,
                N=self.Ns[level],
                S=self.Ss[level],
                activation=self.activation,
                num_tfc=unet_args.num_tfc,
            )
            for level in range(self.depth)
        ]

    def call(self, inputs, contracting_layers):
        out = inputs
        for level in reversed(range(self.depth)):
            out = self.dblocks[level](out, contracting_layers[level])
        return out
|
|
|
|
|
|
class Encoder(tf.keras.Model):
    """
    [B, T, F, Ns[0]] => skip connections , [B, T', F', Ns[depth]]

    Encoder side of the U-Net subnetwork. Each E_Block returns a downsampled
    tensor and its full-resolution skip connection; the bottleneck I_Block is
    applied to the deepest downsampled tensor.

    Returns (bottleneck features, {level_index: skip_connection}).
    """

    def __init__(self, Ns, Ss, unet_args):
        super(Encoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        # Kept only for backward compatibility: call() no longer caches its
        # skip connections on the model (see call()).
        self.contracting_layers = {}

        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(
                E_Block(
                    layer_idx=i,
                    N0=self.Ns[i],
                    N=self.Ns[i + 1],
                    S=self.Ss[i],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc,
                )
            )

        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x = inputs

        # Collect skip connections in a local dict. Storing them on a model
        # attribute kept per-call (possibly graph-traced) tensors alive on
        # the model between calls.
        contracting_layers = {}
        for i, eblock in enumerate(self.eblocks):
            x, x_contract = eblock(x)
            contracting_layers[i] = x_contract

        x = self.i_block(x)

        return x, contracting_layers
|
|
|
|
|
|
class MultiStage_denoise(tf.keras.Model):
    """
    One- or two-stage U-Net denoiser operating on 2-channel spectrograms.

    Stage 1: optional frequency encoding -> 7x7 conv -> U-Net (Encoder /
    Decoder). With ``num_stages > 1``, a SAM module turns the stage-1 features
    into a first prediction and attention-guided features. Stage 2 runs its
    own 7x7 conv on the (encoded) input, concatenates the SAM features (or the
    raw stage-1 features if ``use_SAM`` is false) and runs a second U-Net.

    Returns ``(pred_stage_2, pred_stage_1)`` when ``num_stages > 1``, otherwise
    only ``pred_stage_1``. Every prediction has 2 channels.
    """

    def __init__(self, unet_args=None):
        super(MultiStage_denoise, self).__init__()

        # Every setting below is read from unet_args, so None cannot work;
        # fail with a clear message instead of an AttributeError on None.
        if unet_args is None:
            raise ValueError(
                "MultiStage_denoise requires a unet_args configuration object"
            )

        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            # Only created when enabled; call() checks use_fencoding first.
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages

        # Encoder
        # Channels per U-Net level (Ns[0] = stem width, Ns[depth] = bottleneck)
        # and per-level strides; depth can therefore be at most len(Ss) == 6.
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]

        # initial feature extractor (stage 1)
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(
            filters=self.Ns[0],
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
            activation=self.activation,
        )

        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)

        # NOTE(review): cropconcat and cropadd are not used in call().
        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()

        self.finalblock = FinalBlock()

        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])

            # initial feature extractor (stage 2)
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(
                filters=self.Ns[0],
                kernel_size=ksize,
                kernel_initializer=TruncatedNormal(),
                strides=1,
                padding="VALID",
                activation=self.activation,
            )

            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):
        if self.use_fencoding:
            x_w_freq = self.freq_encoding(inputs)  # e.g. None, None, 1025, 12
        else:
            x_w_freq = inputs

        # initial feature extractor (stage 1)
        x = tf.pad(x_w_freq, self.paddings_1, mode="SYMMETRIC")
        x = self.conv2d_1(x)  # None, None, F, Ns[0]

        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(
            x, contracting_layers_s1
        )  # None, None, F, Ns[0] features

        if self.num_stages > 1:
            # SAM module. It runs even when use_sam is False because the
            # stage-1 prediction comes from it.
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)

            # initial feature extractor (stage 2), applied to the same
            # (optionally frequency-encoded) input as stage 1
            x = tf.pad(x_w_freq, self.paddings_2, mode="SYMMETRIC")
            x = self.conv2d_2(x)

            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x, feats_s1], axis=-1)

            x, contracting_layers_s2 = self.encoder_s2(x)

            feats_s2 = self.decoder_s2(
                x, contracting_layers_s2
            )  # None, None, F, Ns[0] features

            # consider implementing a third stage?

            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1
|
|
|
|
|
|
class I_Block(layers.Layer):
    """
    [B, T, F, C] => [B, T, F, N]

    Intermediate block: a DenseBlock of ``num_tfc`` 3x3 conv. layers plus a
    residual connection through a 1x1 convolution (no activation) that
    projects the input to N channels, so C may differ from N.
    """

    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)

        ksize = (3, 3)
        self.tfc = DenseBlock(num_tfc, N, ksize, activation)

        self.conv2d_res = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            padding="VALID",
        )

    def call(self, inputs):
        x = self.tfc(inputs)

        inputs_proj = self.conv2d_res(inputs)
        # Plain tensor addition instead of building a new keras Add layer on
        # every call.
        return x + inputs_proj
|
|
|
|
|
|
class E_Block(layers.Layer):
    """
    [B, T, F, C] => [B, ceil(T / S[0]), ceil(F / S[1]), N] , [B, T, F, N0]

    Encoder block: an I_Block at the current resolution (producing the skip
    connection) followed by a strided convolution that downsamples it.

    Returns (downsampled features, skip connection).
    """

    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        # Stored so get_config() can describe every constructor argument.
        self.num_tfc = num_tfc
        self.i_block = I_Block(N0, activation, num_tfc)

        # Kernel is 2 larger than the stride; get_paddings() pads by
        # ksize - 1 in total, so the VALID strided conv yields ceil(size / S).
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(
            filters=N,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=S,
            padding="VALID",
            activation=self.activation,
        )

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)

        x_down = tf.pad(x, self.paddings_2, mode="SYMMETRIC")
        x_down = self.conv2d_2(x_down)

        return x_down, x

    def get_config(self):
        # Previously N0, activation and num_tfc were missing, so
        # from_config() (which calls cls(**config)) raised TypeError.
        # NOTE(review): activation is stored as given; a callable activation
        # will not be JSON-serializable.
        config = super(E_Block, self).get_config()
        config.update(
            layer_idx=self.layer_idx,
            N0=self.N0,
            N=self.N,
            S=self.S,
            activation=self.activation,
            num_tfc=self.num_tfc,
        )
        return config
|
|
|
|
|
|
class D_Block(layers.Layer):
    """
    Decoder block: upsamples ``inputs`` by S, merges the result with the skip
    connection ``bridge`` and refines it with an I_Block to N channels.

    Upsampling sums two paths: a learned transposed convolution and
    nearest-neighbour upsampling (projected by a 1x1 conv when its channel
    count differs from N).
    """

    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        # Stored so get_config() can describe every constructor argument.
        self.num_tfc = num_tfc
        ksize = (S[0] + 2, S[1] + 2)
        # NOTE(review): not applied in call(); the transposed conv runs on
        # the unpadded inputs. Kept so the layer's attributes are unchanged.
        self.paddings_1 = get_paddings(ksize)

        self.tconv_1 = layers.Conv2DTranspose(
            filters=N,
            kernel_size=ksize,
            kernel_initializer=TruncatedNormal(),
            strides=S,
            activation=self.activation,
            padding="VALID",
        )

        self.upsampling = layers.UpSampling2D(size=S, interpolation="nearest")

        self.projection = layers.Conv2D(
            filters=N,
            kernel_size=(1, 1),
            kernel_initializer=TruncatedNormal(),
            strides=1,
            activation=self.activation,
            padding="VALID",
        )
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()

        self.i_block = I_Block(N, activation, num_tfc)

    def call(
        self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs
    ):
        # previous_encoder / previous_decoder are accepted but unused.
        #
        # The original code padded `inputs` here and then discarded the
        # result; that dead tf.pad is removed without changing the output.
        # With VALID padding and kernel S + 2, the transposed conv output is
        # (size - 1) * S + S + 2 = size * S + 2 along each axis, i.e. larger
        # than the upsampled path, and is center-cropped below.
        x = self.tconv_1(inputs)

        x2 = self.upsampling(inputs)

        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)

        # Center-crop x to x2's spatial size and add the two paths.
        x = self.cropadd(x, x2)

        # Center-crop x to the skip connection's size and concatenate.
        x = self.cropconcat(x, bridge)

        x = self.i_block(x)
        return x

    def get_config(self):
        # Previously activation and num_tfc were missing, so from_config()
        # (which calls cls(**config)) raised TypeError.
        # NOTE(review): activation is stored as given; a callable activation
        # will not be JSON-serializable.
        config = super(D_Block, self).get_config()
        config.update(
            layer_idx=self.layer_idx,
            N=self.N,
            S=self.S,
            activation=self.activation,
            num_tfc=self.num_tfc,
        )
        return config
|
|
|
|
|
|
class CropAddBlock(layers.Layer):
    """
    Center-crop ``down_layer`` along axes 1 and 2 to the size of ``x`` and
    add the two element-wise.

    When a size difference is odd, the extra row/column is dropped from the
    end of ``down_layer``.
    """

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)

        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        down_layer_cropped = down_layer[
            :,
            height_diff : (x2_shape[1] + height_diff),
            width_diff : (x2_shape[2] + width_diff),
            :,
        ]

        # Plain tensor addition instead of instantiating a keras Add layer
        # on every call.
        return down_layer_cropped + x
|
|
|
|
|
|
class CropConcatBlock(layers.Layer):
    """
    Center-crop ``down_layer`` along axes 1 and 2 to the size of ``x`` and
    concatenate the two along the last (channel) axis, cropped tensor first.

    When a size difference is odd, the extra row/column is dropped from the
    end of ``down_layer``.
    """

    def call(self, down_layer, x, **kwargs):
        big = tf.shape(down_layer)
        small = tf.shape(x)

        top = (big[1] - small[1]) // 2
        left = (big[2] - small[2]) // 2

        cropped = down_layer[:, top : small[1] + top, left : small[2] + left, :]
        return tf.concat([cropped, x], axis=-1)
|