import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
from keras import layers
@keras.saving.register_keras_serializable()
class Sampling(layers.Layer):
    """Sampling layer for the VAE (reparameterization trick).

    Takes the mean and log-variance of a diagonal Gaussian over the latent
    space and returns ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with
    ``epsilon ~ N(0, I)``, so gradients can flow through ``z_mean`` and
    ``z_log_var``.
    """

    def call(self, inputs):
        """Sample a latent vector from N(z_mean, exp(z_log_var)).

        Args:
            inputs (tuple): ``(z_mean, z_log_var)`` tensors of identical shape.

        Returns:
            tensor: Sampled latent tensor with the same shape and dtype as
            ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Match z_mean's dynamic shape (any rank) and dtype (e.g. float16
        # under mixed precision) instead of assuming rank-2 float32.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
@keras.saving.register_keras_serializable()
class VAE(keras.Model):
    """Variational Autoencoder combining an encoder and a decoder model.

    The encoder must return ``[z_mean, z_log_var, z]`` (as built by
    ``encoder_create``). The decoder must map a latent ``z`` back to an image
    (as built by ``decoder_create``).
    """

    def __init__(self, encoder, decoder, **kwargs):
        """Constructor for VAE class.

        Args:
            encoder (keras.Model): Encoder returning ``[z_mean, z_log_var, z]``.
            decoder (keras.Model): Decoder mapping ``z`` to a reconstruction.
            **kwargs: Additional arguments passed to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def get_config(self):
        """Return a serializable config: base Model config plus sub-models."""
        base_config = super().get_config()
        return {
            **base_config,
            "encoder": keras.saving.serialize_keras_object(self.encoder),
            "decoder": keras.saving.serialize_keras_object(self.decoder),
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a VAE from ``get_config`` output, deserializing sub-models."""
        config = dict(config)
        encoder = keras.saving.deserialize_keras_object(config.pop("encoder"))
        decoder = keras.saving.deserialize_keras_object(config.pop("decoder"))
        return cls(encoder, decoder, **config)

    @property
    def metrics(self):
        """Returns the list of metrics (reset by Keras at each epoch)."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Performs a single training step.

        Args:
            data: A batch of images, or an ``(x, ...)`` tuple/list whose first
                element is the image batch. Pixel values are expected in
                [0, 1] because the reconstruction term is binary cross-entropy.

        Returns:
            dict: Running means of ``loss``, ``reconstruction_loss`` and
            ``kl_loss``.
        """
        # fit(x, y) delivers (x, y) pairs; the autoencoder only needs x.
        if isinstance(data, (tuple, list)):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # BCE (averaged over the last axis by Keras), summed over axes 1
            # and 2, then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # Closed-form KL(N(mu, sigma^2) || N(0, I)) with
            # z_log_var = log(sigma^2) (Sampling uses exp(0.5 * z_log_var) as
            # the std): 0.5 * sum(sigma^2 + mu^2 - log(sigma^2) - 1) over the
            # latent dims, then averaged over the batch to a scalar.
            kl_loss = tf.reduce_mean(
                0.5
                * tf.reduce_sum(
                    tf.exp(z_log_var) + tf.square(z_mean) - z_log_var - 1.0,
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, x):
        """Encode ``x``, then decode the sampled latent vector.

        Args:
            x: Input data.

        Returns:
            tensor: Decoded output.
        """
        # The encoder returns [z_mean, z_log_var, z]; the decoder takes only z.
        _, _, z = self.encoder(x)
        return self.decoder(z)
def encoder_create(latent_dim=1024):
    """Build the convolutional encoder for 64x64 RGB images.

    Args:
        latent_dim (int): Width of the hidden dense layer and of the latent
            space.

    Returns:
        keras.Model: Model named "encoder" mapping a (64, 64, 3) input to
        ``[z_mean, z_log_var, z]``.
    """
    # (filters, kernel_size, strides) for each ReLU conv, in order. Every
    # stride-2 conv halves the spatial size (64 -> 32 -> 16 -> 8 -> 4).
    conv_specs = (
        (32, 4, 2), (32, 3, 1),
        (64, 4, 2), (64, 3, 1),
        (128, 4, 2), (128, 3, 1),
        (256, 4, 2), (256, 3, 1),
    )
    image_in = keras.Input(shape=(64, 64, 3))
    features = image_in
    for filters, kernel, stride in conv_specs:
        features = layers.Conv2D(
            filters, kernel, strides=stride, activation="relu", padding="same"
        )(features)
    flat = layers.Flatten()(features)
    hidden = layers.Dense(latent_dim, activation="relu")(flat)
    z_mean = layers.Dense(latent_dim, name="z_mean")(hidden)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(hidden)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model(image_in, [z_mean, z_log_var, z], name="encoder")
def decoder_create(latent_dim=1024):
    """Build the transposed-convolution decoder producing 64x64 RGB images.

    Args:
        latent_dim (int): Dimensionality of the latent input vector.

    Returns:
        keras.Model: Model named "decoder" mapping a (latent_dim,) vector to a
        (64, 64, 3) sigmoid-activated image.
    """
    # (filters, kernel_size, strides) for each ReLU transposed conv, in order.
    # Every stride-2 layer doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64).
    deconv_specs = (
        (256, 4, 2),
        (128, 4, 2), (128, 3, 1),
        (64, 4, 2), (64, 3, 1),
        (32, 4, 2), (32, 3, 1),
    )
    z_in = keras.Input(shape=(latent_dim,))
    h = layers.Dense(4 * 4 * 256, activation="relu")(z_in)
    h = layers.Reshape((4, 4, 256))(h)
    for filters, kernel, stride in deconv_specs:
        h = layers.Conv2DTranspose(
            filters, kernel, strides=stride, activation="relu", padding="same"
        )(h)
    # Sigmoid keeps pixel values in [0, 1].
    image_out = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(h)
    return keras.Model(z_in, image_out, name="decoder")
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    import pooch

    # Resolve paths relative to this file (not the current working directory)
    # so the script behaves the same wherever it is launched from.
    here = os.path.dirname(os.path.abspath(__file__))

    weights_name = "felon_finder_vae.weights.h5"
    POOCH = pooch.create(
        # Cache the download in the repository's saved_model folder.
        path=os.path.join(here, "..", "saved_model"),
        # The weights are hosted on Zenodo. pooch builds download URLs as
        # base_url + file name, so base_url must be the directory; the exact
        # download link (with its query string) goes in `urls`.
        base_url="https://zenodo.org/records/10957695/files/",
        urls={
            weights_name: (
                "https://zenodo.org/records/10957695/files/"
                "felon_finder_vae.weights.h5?download=1"
            ),
        },
        # File name -> SHA256 hash, checked after download.
        registry={
            weights_name: "38595de4ea78d8a1ba21f7b1a3a3b3c7c1c4c862bf3b290c5cd9d2ebaccc16fa",
        },
    )
    vae_weights = POOCH.fetch(weights_name)
    vae = VAE(encoder_create(), decoder_create())
    vae.load_weights(vae_weights)

    # Load CelebA face images for testing the model.
    data_dir = os.path.join(here, "..", "img", "faces")
    batch_size = 128
    img_height = 64
    img_width = 64
    n_show = 10  # number of original/reconstruction pairs to plot
    train_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        labels=None,
    )
    # Only n_show images are plotted, so take just those instead of loading
    # (and repeatedly concatenating) the whole training split.
    sample_ds = train_ds.unbatch().take(n_show).batch(n_show)
    x_train = np.concatenate([batch.numpy() for batch in sample_ds], axis=0)
    x_train = x_train.astype("float32") / 255

    # Reconstruct the faces: encode to a sampled latent z, then decode it.
    _, _, z = vae.encoder(x_train)
    x_decoded = np.asarray(vae.decoder(z))

    plt.figure(figsize=(20, 4))
    for i, (original, reconstructed) in enumerate(zip(x_train, x_decoded)):
        # Top row: input image.
        plt.subplot(2, n_show, i + 1)
        plt.title("original")
        plt.imshow(original)
        plt.axis("off")
        # Bottom row: its reconstruction.
        plt.subplot(2, n_show, n_show + i + 1)
        plt.title("reconstructed")
        plt.imshow(reconstructed)
        plt.axis("off")
    plt.show()