Commit bfd12ce

Multiple changes based on gemini code-review
1 parent 7544abb commit bfd12ce


examples/vision/nl_image_search.py

Lines changed: 15 additions & 14 deletions
@@ -72,7 +72,7 @@
 
 root_dir = "datasets"
 annotations_dir = os.path.join(root_dir, "annotations")
-images_dir = os.path.join(root_dir, "train2014", "train2014")
+images_dir = os.path.join(root_dir, "train2014")
 tfrecords_dir = os.path.join(root_dir, "tfrecords")
 annotation_file = os.path.join(
     annotations_dir, "annotations", "captions_train2014.json"
@@ -91,7 +91,7 @@
     image_zip = wget.download("http://images.cocodataset.org/zips/train2014.zip")
     print("Downloaded the images.\nunzipping")
     with zipfile.ZipFile(image_zip, "r") as zip_ref:
-        zip_ref.extractall(images_dir)
+        zip_ref.extractall(root_dir)
 
 print("\nDataset is downloaded and extracted successfully.")
 
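The two hunks above go together: the COCO train2014.zip already stores its members under a top-level train2014/ folder, so extracting into root_dir yields datasets/train2014/... directly, and images_dir no longer needs the doubled "train2014" component. A minimal sketch of how one might confirm the archive layout before extracting (the member name in the comment is an assumption for illustration, not taken from the diff):

import zipfile

# Peek at the archive layout (illustrative sketch only).
with zipfile.ZipFile("train2014.zip", "r") as zip_ref:
    # Members are expected to look like "train2014/COCO_train2014_000000000009.jpg",
    # so zip_ref.extractall(root_dir) would place them in datasets/train2014/.
    print(zip_ref.namelist()[0])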
@@ -104,13 +104,12 @@
     image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
     image_path_to_caption[image_path].append(caption)
 
-images = glob.glob("datasets/train2014/*.jpg")
+images = glob.glob(os.path.join(images_dir, "*.jpg"))
 image_paths = list(image_path_to_caption.keys())
 if len(images) != len(image_paths):
-    print(
-        f"Not all images extracted correctly, expected {len(image_paths)} images, found
-        {len(images)} images"
-    )
+    print(f"Not all images extracted correctly,\n",
+          f"expected {len(image_paths)} images,\n",
+          f"found: {len(images)} images")
 print(f"Number of images: {len(image_paths)}")
 
 """
@@ -188,7 +187,8 @@ def write_data(image_paths, num_files, files_prefix):
 else:
     print(f"{num_train_files} tfrecord files found.")
     print(f"{num_train_files*images_per_file} training examples in the tfrecord files.")
-    train_example_count = 60000
+    train_example_count = train_size * captions_per_image
+
 
 found_files = glob.glob(os.path.join(root_dir, "tfrecords", "valid-*.tfrecord"))
 if len(found_files) != num_valid_files:
@@ -199,7 +199,7 @@ def write_data(image_paths, num_files, files_prefix):
 else:
     print(f"{num_valid_files} tfrecord files found.")
     print(f"{num_valid_files*images_per_file} training examples in the tfrecord files.")
-    valid_example_count = 10000
+    valid_example_count = valid_size * captions_per_image
 
 """
 ### Create a
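For reference, the derived counts in these two hunks reproduce the constants they replace, provided the dataset sizes defined earlier in the example match the values assumed below (those values are not visible in this diff and are restated here only for illustration):

# Assumed configuration from earlier in the example (not shown in this diff).
train_size = 30000
valid_size = 5000
captions_per_image = 2

train_example_count = train_size * captions_per_image  # 60000, the old hardcoded value
valid_example_count = valid_size * captions_per_image  # 10000, the old hardcoded value
print(train_example_count, valid_example_count)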
@@ -441,25 +441,26 @@ def call(self, features, training=False):
         return caption_embeddings, image_embeddings
 
     def compute_loss(self, caption_embeddings, image_embeddings):
-        # logits[i][j] is the dot_similarity(caption_i, image_j).
+        # similarity between all image and caption embeddings
         logits = ops.divide(
             ops.einsum("ae,be -> ab", caption_embeddings, image_embeddings),
             self.temperature,
         )
 
-        # images_similarity[i][j] is the dot_similarity(image_i, image_j).
+        # similarity between all image and image embeddings
         images_similarity = ops.einsum(
             "ae,be -> ab", image_embeddings, image_embeddings
        )
-        # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
+
+        # similarity between all caption and caption embeddings
         captions_similarity = ops.einsum(
             "ae,be -> ab", caption_embeddings, caption_embeddings
         )
-        # targets[i][j] = avarage dot_similarity(caption_i, caption_j) and
-        # dot_similarity(image_i, image_j).
+
         targets = keras.activations.softmax(
             (captions_similarity + images_similarity) / (2 * self.temperature)
         )
+
         # Compute the loss for the captions using cross-entropy
         captions_loss = keras.losses.categorical_crossentropy(
             y_true=targets, y_pred=logits, from_logits=True
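Although the comments were simplified, the computation is unchanged: logits[i][j] is the temperature-scaled dot similarity between caption i and image j, and the soft targets are the softmax of the averaged caption-caption and image-image similarities. Below is a minimal standalone NumPy sketch of the same idea, with made-up embeddings, an assumed temperature value, and an assumed symmetric images_loss / averaging tail (the hunk cuts off after captions_loss), so it is an illustration rather than the example's exact implementation:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

rng = np.random.default_rng(0)
caption_embeddings = rng.normal(size=(4, 8))  # made-up batch of 4 caption vectors
image_embeddings = rng.normal(size=(4, 8))    # made-up batch of 4 image vectors
temperature = 0.05                            # assumed value, for illustration only

# logits[i][j]: dot similarity between caption i and image j, scaled by temperature
logits = caption_embeddings @ image_embeddings.T / temperature

# soft targets: softmax of the averaged caption-caption and image-image similarities
images_similarity = image_embeddings @ image_embeddings.T
captions_similarity = caption_embeddings @ caption_embeddings.T
targets = softmax((captions_similarity + images_similarity) / (2 * temperature))

# cross-entropy against the soft targets, in both caption-to-image and image-to-caption directions
captions_loss = -(targets * log_softmax(logits)).sum(axis=-1)
images_loss = -(targets.T * log_softmax(logits.T)).sum(axis=-1)
print(float((captions_loss + images_loss).mean() / 2))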
