33
44layers = tf .keras .layers
55
6-
7- @as_KerasModel
86class GAN (tf .keras .Model ):
7+ """Generative Adversarial Network (GAN) model.
8+
9+ Parameters:
10+ discriminator: keras model, optional
11+ The discriminator network.
12+ generator: keras model, optional
13+ The generator network.
14+ latent_dim: int, optional
15+ Dimension of the latent space for random vectors.
16+ """
17+
918 def __init__ (self , discriminator = None , generator = None , latent_dim = 128 ):
1019 super (GAN , self ).__init__ ()
1120
21+ # Initialize discriminator and generator, or use default if not provided
1222 if discriminator is None :
1323 discriminator = self .default_discriminator ()
14-
1524 if generator is None :
1625 generator = self .default_generator ()
1726
@@ -21,9 +30,13 @@ def __init__(self, discriminator=None, generator=None, latent_dim=128):
2130
2231 def compile (self , d_optimizer , g_optimizer , loss_fn ):
2332 super (GAN , self ).compile ()
33+
34+ # Set optimizers and loss function for training
2435 self .d_optimizer = d_optimizer
2536 self .g_optimizer = g_optimizer
2637 self .loss_fn = loss_fn
38+
39+ # Define metrics to track during training
2740 self .d_loss_metric = tf .keras .metrics .Mean (name = "d_loss" )
2841 self .g_loss_metric = tf .keras .metrics .Mean (name = "g_loss" )
2942
@@ -36,17 +49,18 @@ def train_step(self, real_images):
3649 batch_size = tf .shape (real_images )[0 ]
3750 random_latent_vectors = tf .random .normal (shape = (batch_size , self .latent_dim ))
3851
39- # Decode them to fake images
52+ # Generate fake images using the generator
4053 generated_images = self .generator (random_latent_vectors )
4154
42- # Combine them with real images
55+ # Combine real and fake images
4356 combined_images = tf .concat ([generated_images , real_images ], axis = 0 )
4457
45- # Assemble labels discriminating real from fake images
58+ # Create labels for real and fake images
4659 labels = tf .concat (
4760 [tf .ones ((batch_size , 1 )), tf .zeros ((batch_size , 1 ))], axis = 0
4861 )
49- # Add random noise to the labels - important trick!
62+
63+ # Add random noise to labels to improve stability
5064 labels += 0.05 * tf .random .uniform (tf .shape (labels ))
5165
5266 # Train the discriminator
@@ -58,14 +72,13 @@ def train_step(self, real_images):
5872 zip (grads , self .discriminator .trainable_weights )
5973 )
6074
61- # Sample random points in the latent space
75+ # Generate new random latent vectors
6276 random_latent_vectors = tf .random .normal (shape = (batch_size , self .latent_dim ))
6377
64- # Assemble labels that say "all real images"
78+ # Create labels indicating "all real images" for generator training
6579 misleading_labels = tf .zeros ((batch_size , 1 ))
6680
67- # Train the generator (note that we should *not* update the weights
68- # of the discriminator)!
81+ # Train the generator while keeping discriminator weights fixed
6982 with tf .GradientTape () as tape :
7083 predictions = self .discriminator (self .generator (random_latent_vectors ))
7184 g_loss = self .loss_fn (misleading_labels , predictions )
@@ -75,12 +88,19 @@ def train_step(self, real_images):
7588 # Update metrics
7689 self .d_loss_metric .update_state (d_loss )
7790 self .g_loss_metric .update_state (g_loss )
91+
92+ # Return updated loss metrics
7893 return {
7994 "d_loss" : self .d_loss_metric .result (),
8095 "g_loss" : self .g_loss_metric .result (),
8196 }
8297
98+ def call (self , inputs ):
99+ # Run generator
100+ return self .generator (inputs )
101+
83102 def default_generator (self , latent_dim = 128 ):
103+ # Define the default generator architecture
84104 return tf .keras .Sequential (
85105 [
86106 tf .keras .Input (shape = (latent_dim ,)),
@@ -98,6 +118,7 @@ def default_generator(self, latent_dim=128):
98118 )
99119
100120 def default_discriminator (self ):
121+ # Define the default discriminator architecture
101122 return tf .keras .Sequential (
102123 [
103124 tf .keras .Input (shape = (64 , 64 , 3 )),
@@ -112,4 +133,4 @@ def default_discriminator(self):
112133 layers .Dense (1 , activation = "sigmoid" ),
113134 ],
114135 name = "discriminator" ,
115- )
136+ )
0 commit comments