We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b90cb20 commit 091e11cCopy full SHA for 091e11c
train.py
@@ -23,6 +23,8 @@ def train():
23
24
for epoch in range(flags.n_epoch):
25
for step, batch_images in enumerate(images):
26
+ if batch_images.shape[0] != flags.batch_size: # if the remaining data in this epoch < batch_size
27
+ break
28
step_time = time.time()
29
with tf.GradientTape(persistent=True) as tape:
30
# z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
0 commit comments