denoising-historical-data/unet.py
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m


def build_model_denoise(unet_args=None):
    inputs = Input(shape=(None, None, 2))
    outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
    # Encapsulating MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
    return model
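
# Configuration note: `unet_args` is expected to provide the attributes
# `activation`, `depth`, `num_tfc`, `use_fencoding`, `use_SAM`, `num_stages`
# and `f_dim` (these are the fields read in this module); an example
# configuration with assumed values is sketched at the bottom of the file.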


class DenseBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    DenseNet Block consisting of "num_layers" densely connected convolutional layers
    '''
    def __init__(self, num_layers, N, ksize, activation):
        '''
        num_layers: number of densely connected conv. layers
        N: Number of filters (same in each layer)
        ksize: Kernel size (same in each layer)
        '''
        super(DenseBlock, self).__init__()
        self.activation = activation
        self.paddings_1 = get_paddings(ksize)
        self.H = []
        self.num_layers = num_layers
        for i in range(num_layers):
            self.H.append(layers.Conv2D(filters=N,
                                        kernel_size=ksize,
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID',
                                        activation=self.activation))

    def call(self, x):
        x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x_ = self.H[0](x_)
        if self.num_layers > 1:
            for h in self.H[1:]:
                x = tf.concat([x_, x], axis=-1)
                x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
                x_ = h(x_)
        return x_
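
# Channel bookkeeping, for reference: layer i of a DenseBlock sees the block
# input concatenated with the outputs of all previous layers (C_in + i*N input
# channels) and always produces N output channels, so the block output has N
# channels regardless of num_layers.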


class FinalBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, 2]
    Final block: a 3x3 conv. layer that maps the output features to the output complex spectrogram.
    '''
    def __init__(self):
        super(FinalBlock, self).__init__()
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

    def call(self, inputs):
        x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        pred = self.conv2(x)
        return pred


class SAM(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N] , [B, T, F, 2]
    Supervised Attention Module:
    The purpose of SAM is to make the network propagate only the most relevant features to the second stage, discarding the less useful ones.
    The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer.
    The first stage output is then calculated by adding the original input spectrogram to the residual noise.
    The attention-guided features are computed using the attention masks M, which are calculated directly from the first stage output with a 3x3 convolution and a sigmoid function.
    '''
    def __init__(self, n_feat):
        super(SAM, self).__init__()
        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        ksize = (3, 3)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        x1 = tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
        x1 = self.conv1(x1)
        x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        x = self.conv2(x)
        # residual prediction: first stage output = input + estimated residual
        pred = layers.Add()([x, input_spectrogram])
        x3 = tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
        M = self.conv3(x3)
        M = tf.keras.activations.sigmoid(M)
        x1 = layers.Multiply()([x1, M])
        x1 = layers.Add()([x1, inputs])  # features to next stage
        return x1, pred
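
# SAM data flow, for reference: pred = conv2(features) + input_spectrogram is
# the stage-1 spectrogram estimate, M = sigmoid(conv3(pred)) is a per-element
# attention mask, and the features handed to stage 2 are
# conv1(features) * M + features.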


class AddFreqEncoding(layers.Layer):
    '''
    [B, T, F, 2] => [B, T, F, 12]
    Generates frequency positional embeddings and concatenates them as 10 extra channels.
    The embedding table is precomputed for a fixed f_dim (F=1025 in this project).
    '''
    def __init__(self, f_dim):
        super(AddFreqEncoding, self).__init__()
        pi = tf.constant(m.pi)
        pi = tf.cast(pi, 'float32')
        self.f_dim = f_dim  # f_dim is fixed
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), 'float32')
        coss = tf.math.cos(pi * n)
        f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
        self.fembeddings = f_channel
        for k in range(1, 10):
            coss = tf.math.cos(2**k * pi * n)
            f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
            self.fembeddings = tf.concat([self.fembeddings, f_channel], axis=-1)  # (1025, 10)

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension
        fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
        return tf.concat([input_tensor, fembeddings_2], axis=-1)  # (batch, 427, 1025, 12)
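
# For reference: channel k (k = 0..9) of the frequency encoding is
# cos(2**k * pi * f / (f_dim - 1)) evaluated at every frequency bin f, giving
# cosines of increasing frequency along the F axis as a positional code.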


def get_paddings(K):
    return tf.constant([[0, 0],
                        [K[0] // 2, K[0] // 2 - (1 - K[0] % 2)],
                        [K[1] // 2, K[1] // 2 - (1 - K[1] % 2)],
                        [0, 0]])
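
# Worked examples, for illustration: get_paddings((3, 3)) returns
# [[0, 0], [1, 1], [1, 1], [0, 0]] and get_paddings((4, 4)) returns
# [[0, 0], [2, 1], [2, 1], [0, 0]]; SYMMETRIC-padding an input by these
# amounts before a 'VALID' convolution reproduces 'same'-style output sizes
# while avoiding zero-padding artifacts at the spectrogram borders.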


class Decoder(layers.Layer):
    '''
    [B, T, F, N] , skip connections => [B, T, F, N]
    Decoder side of the U-Net subnetwork.
    '''
    def __init__(self, Ns, Ss, unet_args):
        super(Decoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        ksize = (3, 3)
        self.paddings_3 = get_paddings(ksize)
        self.conv2d_3 = layers.Conv2D(filters=self.Ns[self.depth],
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)
        self.cropadd = CropAddBlock()
        self.dblocks = []
        for i in range(self.depth):
            self.dblocks.append(D_Block(layer_idx=i, N=self.Ns[i], S=self.Ss[i],
                                        activation=self.activation, num_tfc=unet_args.num_tfc))

    def call(self, inputs, contracting_layers):
        x = inputs
        for i in range(self.depth, 0, -1):
            x = self.dblocks[i-1](x, contracting_layers[i-1])
        return x


class Encoder(tf.keras.Model):
    '''
    [B, T, F, N] => skip connections , [B, T, F, N_4]
    Encoder side of the U-Net subnetwork.
    '''
    def __init__(self, Ns, Ss, unet_args):
        super(Encoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        self.contracting_layers = {}
        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(E_Block(layer_idx=i, N0=self.Ns[i], N=self.Ns[i+1], S=self.Ss[i],
                                        activation=self.activation, num_tfc=unet_args.num_tfc))
        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x = inputs
        for i in range(self.depth):
            x, x_contract = self.eblocks[i](x)
            self.contracting_layers[i] = x_contract  # if remove 0, correct this
        x = self.i_block(x)
        return x, self.contracting_layers
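
# For reference: contracting_layers[i] stores the pre-downsampling output of
# E_Block i so that Decoder.call can pass it to D_Block i as the U-Net skip
# connection.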


class MultiStage_denoise(tf.keras.Model):

    def __init__(self, unet_args=None):
        super(MultiStage_denoise, self).__init__()
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages
        # Encoder
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
        # initial feature extractor
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)
        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()
        self.finalblock = FinalBlock()
        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])
            # initial feature extractor (second stage)
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
                                          kernel_size=ksize,
                                          kernel_initializer=TruncatedNormal(),
                                          strides=1,
                                          padding='VALID',
                                          activation=self.activation)
            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):
        if self.use_fencoding:
            x_w_freq = self.freq_encoding(inputs)  # None, None, 1025, 12
        else:
            x_w_freq = inputs
        # initial feature extractor
        x = tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
        x = self.conv2d_1(x)  # None, None, 1025, 32
        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(x, contracting_layers_s1)  # None, None, 1025, 32 features
        if self.num_stages > 1:
            # SAM module
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
            # initial feature extractor (second stage)
            x = tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
            x = self.conv2d_2(x)
            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x, feats_s1], axis=-1)
            x, contracting_layers_s2 = self.encoder_s2(x)
            feats_s2 = self.decoder_s2(x, contracting_layers_s2)  # None, None, 1025, 32 features
            # consider implementing a third stage?
            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1
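
# Stage flow, for reference: stage 1 is AddFreqEncoding -> conv2d_1 ->
# encoder_s1 -> decoder_s1. With num_stages > 1, SAM turns the stage-1
# features into (gated features, stage-1 prediction); stage 2 re-extracts
# features from the same input with conv2d_2, concatenates them with the
# gated features (or with the raw stage-1 features if use_SAM is False), and
# runs encoder_s2 -> decoder_s2 -> FinalBlock to produce the stage-2 prediction.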


class I_Block(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    Intermediate block: a DenseBlock with a residual connection
    (the input is projected with a 1x1 convolution before being added).
    '''
    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)
        ksize = (3, 3)
        self.tfc = DenseBlock(num_tfc, N, ksize, activation)
        self.conv2d_res = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID')

    def call(self, inputs):
        x = self.tfc(inputs)
        inputs_proj = self.conv2d_res(inputs)
        return layers.Add()([x, inputs_proj])


class E_Block(layers.Layer):

    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.i_block = I_Block(N0, activation, num_tfc)
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(filters=N,
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=S,
                                      padding='VALID',
                                      activation=self.activation)

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)
        x_down = tf.pad(x, self.paddings_2, mode='SYMMETRIC')
        x_down = self.conv2d_2(x_down)
        return x_down, x

    def get_config(self):
        return dict(layer_idx=self.layer_idx,
                    N=self.N,
                    S=self.S,
                    **super(E_Block, self).get_config())


class D_Block(layers.Layer):

    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_1 = get_paddings(ksize)
        self.tconv_1 = layers.Conv2DTranspose(filters=N,
                                              kernel_size=ksize,
                                              kernel_initializer=TruncatedNormal(),
                                              strides=S,
                                              activation=self.activation,
                                              padding='VALID')
        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
        self.projection = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        activation=self.activation,
                                        padding='VALID')
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()
        self.i_block = I_Block(N, activation, num_tfc)

    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
        x = inputs
        x = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
        x = self.tconv_1(x)
        x2 = self.upsampling(inputs)
        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)
        x = self.cropadd(x, x2)
        x = self.cropconcat(x, bridge)
        x = self.i_block(x)
        return x

    def get_config(self):
        return dict(layer_idx=self.layer_idx,
                    N=self.N,
                    S=self.S,
                    **super(D_Block, self).get_config())
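
# For reference: D_Block upsamples through two parallel paths, a learned
# transposed convolution applied to a symmetrically padded input and a
# nearest-neighbour upsampling with a 1x1 projection when the channel counts
# differ; the two are added, the encoder skip connection is concatenated,
# and an I_Block refines the result.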


class CropAddBlock(layers.Layer):

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)
        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2
        down_layer_cropped = down_layer[:,
                                        height_diff: (x2_shape[1] + height_diff),
                                        width_diff: (x2_shape[2] + width_diff),
                                        :]
        x = layers.Add()([down_layer_cropped, x])
        return x


class CropConcatBlock(layers.Layer):

    def call(self, down_layer, x, **kwargs):
        x1_shape = tf.shape(down_layer)
        x2_shape = tf.shape(x)
        height_diff = (x1_shape[1] - x2_shape[1]) // 2
        width_diff = (x1_shape[2] - x2_shape[2]) // 2
        down_layer_cropped = down_layer[:,
                                        height_diff: (x2_shape[1] + height_diff),
                                        width_diff: (x2_shape[2] + width_diff),
                                        :]
        x = tf.concat([down_layer_cropped, x], axis=-1)
        return x
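

# Minimal usage sketch, assuming a SimpleNamespace stands in for the project's
# config object: it only shows how the attributes read by this module fit
# together. The values for activation, num_tfc and the dummy tensor sizes are
# illustrative assumptions; f_dim=1025 matches the frequency encoding, and
# num_stages=2 matches the two outputs expected by build_model_denoise.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        activation='relu',    # assumed activation
        depth=6,              # uses all entries of Ns=[32,64,64,128,128,256,512]
        num_tfc=3,            # assumed number of layers per DenseBlock
        use_fencoding=True,   # concatenate the 10 frequency-encoding channels
        use_SAM=True,         # pass SAM-gated features to the second stage
        num_stages=2,         # build_model_denoise expects two outputs
        f_dim=1025,           # frequency bins of the complex spectrogram
    )

    model = build_model_denoise(unet_args=example_args)
    dummy = tf.random.normal([1, 64, 1025, 2])  # (batch, time, freq, re/im)
    pred_stage_2, pred_stage_1 = model(dummy)
    print(pred_stage_2.shape, pred_stage_1.shape)  # both (1, 64, 1025, 2)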