Skip to content

Commit 091e11c

Browse files
authored
Update train.py
1 parent b90cb20 commit 091e11c

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def train():
2323

2424
for epoch in range(flags.n_epoch):
2525
for step, batch_images in enumerate(images):
26+
if batch_images.shape[0] != flags.batch_size: # if the remaining data in this epoch < batch_size
27+
break
2628
step_time = time.time()
2729
with tf.GradientTape(persistent=True) as tape:
2830
# z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')

0 commit comments

Comments
 (0)