import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m


def build_model_denoise(unet_args=None):
    '''
    Builds a keras.Model wrapping MultiStage_denoise: a complex spectrogram
    [B, T, F, 2] in, the denoised stage-2 and stage-1 spectrograms out.
    (Assumes unet_args.num_stages > 1, since two outputs are unpacked.)
    '''
    inputs = Input(shape=(None, None, 2))

    outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)

    # Encapsulate MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])

    return model


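# Usage sketch (illustrative only): the attribute names below are the ones this
# module reads from unet_args, but the concrete values are assumptions, not the
# project's defaults.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(activation='elu', depth=6, num_tfc=3,
#                          use_fencoding=True, f_dim=1025,
#                          use_SAM=True, num_stages=2)
#   model = build_model_denoise(unet_args=args)
#   noisy = tf.zeros([1, 64, 1025, 2])           # [B, T, F, real/imag]
#   denoised_s2, denoised_s1 = model(noisy)      # both [1, 64, 1025, 2]

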
class DenseBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    DenseNet block consisting of "num_layers" densely connected convolutional layers.
    '''

    def __init__(self, num_layers, N, ksize, activation):
        '''
        num_layers: number of densely connected conv. layers
        N: number of filters (same in each layer)
        ksize: kernel size (same in each layer)
        '''
        super(DenseBlock, self).__init__()
        self.activation = activation

        self.paddings_1 = get_paddings(ksize)
        self.H = []
        self.num_layers = num_layers

        for i in range(num_layers):
            self.H.append(layers.Conv2D(filters=N,
                                        kernel_size=ksize,
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID',
                                        activation=self.activation))

    def call(self, x):
        x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x_ = self.H[0](x_)

        if self.num_layers > 1:
            for h in self.H[1:]:
                # Dense connectivity: each layer sees the previous output
                # concatenated with the accumulated inputs.
                x = tf.concat([x_, x], axis=-1)
                x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
                x_ = h(x_)

        return x_


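# Shape sketch (hedged example; the constructor arguments here are made up):
#
#   block = DenseBlock(num_layers=3, N=32, ksize=(3, 3), activation='elu')
#   block(tf.zeros([1, 16, 1025, 32])).shape     # (1, 16, 1025, 32)
#
# Each layer sees the previous layer's output concatenated with the accumulated
# inputs, but the block returns only the last convolution's N feature maps.

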
class FinalBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, 2]
    Final block: a single 3x3 conv. layer that maps the output features to the
    output complex spectrogram.
    '''

    def __init__(self):
        super(FinalBlock, self).__init__()
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

    def call(self, inputs):
        x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        pred = self.conv2(x)

        return pred


class SAM(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N] , [B, T, F, 2]
    Supervised Attention Module:
    The purpose of SAM is to propagate only the most relevant features to the
    second stage, discarding the less useful ones.
    The estimated residual noise signal is generated from the U-Net output
    features by means of a 3x3 convolutional layer.
    The first-stage output is then calculated by adding the original input
    spectrogram to the residual noise.
    The attention-guided features are computed using the attention masks M,
    which are calculated directly from the first-stage output with a 3x3
    convolution followed by a sigmoid.
    '''

    def __init__(self, n_feat):
        super(SAM, self).__init__()

        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        x1 = tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
        x1 = self.conv1(x1)

        x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        x = self.conv2(x)

        # residual prediction
        pred = layers.Add()([x, input_spectrogram])  # first-stage output spectrogram

        x3 = tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
        M = self.conv3(x3)

        M = tf.keras.activations.sigmoid(M)
        x1 = layers.Multiply()([x1, M])
        x1 = layers.Add()([x1, inputs])  # features passed to the next stage

        return x1, pred


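# Data-flow sketch of SAM.call above (a restatement of the code, not extra logic):
#
#   residual = Conv3x3(features)                 # 2 channels
#   pred     = residual + input_spectrogram      # first-stage output
#   M        = sigmoid(Conv3x3(pred))            # attention masks, n_feat channels
#   out      = Conv3x3(features) * M + features  # attention-guided features for stage 2

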
class AddFreqEncoding(layers.Layer):
    '''
    [B, T, F, 2] => [B, T, F, 12]
    Generates frequency positional embeddings and concatenates them as 10 extra channels.
    This layer is optimized for F=1025.
    '''

    def __init__(self, f_dim):
        super(AddFreqEncoding, self).__init__()
        pi = tf.constant(m.pi)
        pi = tf.cast(pi, 'float32')
        self.f_dim = f_dim  # f_dim is fixed

        n = tf.cast(tf.range(f_dim) / (f_dim - 1), 'float32')
        coss = tf.math.cos(pi * n)
        f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
        self.fembeddings = f_channel

        for k in range(1, 10):
            coss = tf.math.cos(2**k * pi * n)
            f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
            self.fembeddings = tf.concat([self.fembeddings, f_channel], axis=-1)  # (1025, 10)

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension

        fembeddings_2 = tf.broadcast_to(self.fembeddings,
                                        [batch_size_tensor, time_dim, self.f_dim, 10])

        return tf.concat([input_tensor, fembeddings_2], axis=-1)  # (batch, T, 1025, 12)


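# Sketch of the encoding built above: with n = f / (f_dim - 1) in [0, 1], the ten
# extra channels are cos(2**k * pi * n) for k = 0..9, i.e. cosines of increasing
# frequency along the F axis, broadcast over batch and time and concatenated to
# the two input channels (real/imag):
#
#   AddFreqEncoding(1025)(tf.zeros([1, 8, 1025, 2])).shape   # (1, 8, 1025, 12)

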
def get_paddings(K):
    # Per-axis symmetric paddings for a 'VALID' convolution with kernel size K,
    # chosen so that a stride-1 convolution preserves the time and frequency
    # dimensions (asymmetric when a kernel dimension is even).
    return tf.constant([[0, 0],
                        [K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
                        [K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
                        [0, 0]])


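# Worked examples (sketch): for stride-1 convolutions these paddings keep T and F
# unchanged; for even kernel sizes the padding is asymmetric.
#
#   get_paddings((3, 3))   # [[0,0], [1,1], [1,1], [0,0]]
#   get_paddings((7, 7))   # [[0,0], [3,3], [3,3], [0,0]]
#   get_paddings((4, 4))   # [[0,0], [2,1], [2,1], [0,0]]  (kernel used before the strided convs)

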
class Decoder(layers.Layer):
    '''
    [B, T, F, N], skip connections => [B, T, F, N]
    Decoder side of the U-Net subnetwork.
    '''

    def __init__(self, Ns, Ss, unet_args):
        super(Decoder, self).__init__()

        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv2d_3 = layers.Conv2D(filters=self.Ns[self.depth],
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)

        self.cropadd = CropAddBlock()

        self.dblocks = []
        for i in range(self.depth):
            self.dblocks.append(D_Block(layer_idx=i, N=self.Ns[i], S=self.Ss[i],
                                        activation=self.activation,
                                        num_tfc=unet_args.num_tfc))

    def call(self, inputs, contracting_layers):
        x = inputs
        for i in range(self.depth, 0, -1):
            x = self.dblocks[i - 1](x, contracting_layers[i - 1])
        return x


class Encoder(tf.keras.Model):
    '''
    [B, T, F, N] => skip connections, [B, T, F, N_4]
    Encoder side of the U-Net subnetwork.
    '''

    def __init__(self, Ns, Ss, unet_args):
        super(Encoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth

        self.contracting_layers = {}

        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(E_Block(layer_idx=i, N0=self.Ns[i], N=self.Ns[i + 1],
                                        S=self.Ss[i], activation=self.activation,
                                        num_tfc=unet_args.num_tfc))

        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x = inputs
        for i in range(self.depth):
            x, x_contract = self.eblocks[i](x)
            self.contracting_layers[i] = x_contract  # store the skip connection for this level
        x = self.i_block(x)

        return x, self.contracting_layers


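# Channel-count sketch for one U-Net, following the Ns list set in MultiStage_denoise:
# E_Block i maps Ns[i] -> Ns[i+1] channels and also returns its un-downsampled features
# as the skip connection for level i; the bottleneck I_Block works at Ns[depth] channels,
# and the Decoder's D_Blocks walk back down the list so the final feature map has
# Ns[0] = 32 channels again.

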
class MultiStage_denoise(tf.keras.Model):

    def __init__(self, unet_args=None):
        super(MultiStage_denoise, self).__init__()

        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages

        # Encoder
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]

        # initial feature extractor
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)

        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)

        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()

        self.finalblock = FinalBlock()

        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])

            # initial feature extractor (second stage)
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
                                          kernel_size=ksize,
                                          kernel_initializer=TruncatedNormal(),
                                          strides=1,
                                          padding='VALID',
                                          activation=self.activation)

            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):
        if self.use_fencoding:
            x_w_freq = self.freq_encoding(inputs)  # None, None, 1025, 12
        else:
            x_w_freq = inputs

        # initial feature extractor
        x = tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
        x = self.conv2d_1(x)  # None, None, 1025, 32

        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(x, contracting_layers_s1)  # None, None, 1025, 32 features

        if self.num_stages > 1:
            # SAM module
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)

            # initial feature extractor (second stage)
            x = tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
            x = self.conv2d_2(x)

            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x, feats_s1], axis=-1)

            x, contracting_layers_s2 = self.encoder_s2(x)

            feats_s2 = self.decoder_s2(x, contracting_layers_s2)  # None, None, 1025, 32 features

            # TODO: consider implementing a third stage

            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1


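# Shorthand for the two-stage pipeline implemented by call() above (stage 2 only
# runs when num_stages > 1; with num_stages == 1 the stage-1 features go straight
# to FinalBlock):
#
#   x        = concat(spectrogram, freq_encoding)        # optional, 2 -> 12 channels
#   f1       = Decoder1(Encoder1(Conv7x7(x)))            # stage-1 features, 32 channels
#   F, pred1 = SAM(f1, spectrogram)                      # guided features + stage-1 output
#   f2       = Decoder2(Encoder2(concat(Conv7x7(x), F)))
#   pred2    = FinalBlock(f2)                            # returns (pred2, pred1)

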
class I_Block(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    Intermediate block:
    a DenseNet block with a residual connection.
    '''

    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)

        ksize = (3, 3)
        self.tfc = DenseBlock(num_tfc, N, ksize, activation)

        # 1x1 projection so the residual addition works when the input
        # channel count differs from N
        self.conv2d_res = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID')

    def call(self, inputs):
        x = self.tfc(inputs)

        inputs_proj = self.conv2d_res(inputs)
        return layers.Add()([x, inputs_proj])


class E_Block(layers.Layer):

    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.i_block = I_Block(N0, activation, num_tfc)

        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(filters=N,
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=S,
                                      padding='VALID',
                                      activation=self.activation)

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)

        x_down = tf.pad(x, self.paddings_2, mode='SYMMETRIC')
        x_down = self.conv2d_2(x_down)

        return x_down, x

    def get_config(self):
        return dict(layer_idx=self.layer_idx,
                    N=self.N,
                    S=self.S,
                    **super(E_Block, self).get_config())


class D_Block(layers.Layer):

    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_1 = get_paddings(ksize)

        self.tconv_1 = layers.Conv2DTranspose(filters=N,
                                              kernel_size=ksize,
                                              kernel_initializer=TruncatedNormal(),
                                              strides=S,
                                              activation=self.activation,
                                              padding='VALID')

        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')

        self.projection = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        activation=self.activation,
                                        padding='VALID')
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()

        self.i_block = I_Block(N, activation, num_tfc)

    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
        x = inputs
        x = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x = self.tconv_1(x)  # transposed convolution applied to the padded tensor

        x2 = self.upsampling(inputs)

        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)

        x = self.cropadd(x, x2)

        x = self.cropconcat(x, bridge)

        x = self.i_block(x)
        return x

    def get_config(self):
        return dict(layer_idx=self.layer_idx,
                    N=self.N,
                    S=self.S,
                    **super(D_Block, self).get_config())


class CropAddBlock(layers.Layer):

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)

        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        # center-crop down_layer to the spatial size of x, then add
        down_layer_cropped = down_layer[:,
                                        height_diff: (x2_shape[1] + height_diff),
                                        width_diff: (x2_shape[2] + width_diff),
                                        :]

        x = layers.Add()([down_layer_cropped, x])
        return x


class CropConcatBlock(layers.Layer):

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)

        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2

        # center-crop down_layer to the spatial size of x, then concatenate along channels
        down_layer_cropped = down_layer[:,
                                        height_diff: (x2_shape[1] + height_diff),
                                        width_diff: (x2_shape[2] + width_diff),
                                        :]

        x = tf.concat([down_layer_cropped, x], axis=-1)
        return x


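# A small optional smoke test (not part of the original module): it only exercises
# the padding helper, the frequency encoding and the crop/concat block, so it runs
# quickly without building the full model.
if __name__ == "__main__":
    print(get_paddings((3, 3)).numpy().tolist())   # [[0, 0], [1, 1], [1, 1], [0, 0]]

    spec = tf.zeros([1, 8, 1025, 2])
    print(AddFreqEncoding(1025)(spec).shape)       # (1, 8, 1025, 12)

    a = tf.zeros([1, 10, 12, 4])
    b = tf.zeros([1, 8, 8, 4])
    print(CropConcatBlock()(a, b).shape)           # (1, 8, 8, 8)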