root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
- images_dir = os.path.join(root_dir, "train2014", "train2014")
+ images_dir = os.path.join(root_dir, "train2014")
tfrecords_dir = os.path.join(root_dir, "tfrecords")
annotation_file = os.path.join(
    annotations_dir, "annotations", "captions_train2014.json"
)
    image_zip = wget.download("http://images.cocodataset.org/zips/train2014.zip")
    print("Downloaded the images.\nUnzipping...")
    with zipfile.ZipFile(image_zip, "r") as zip_ref:
-         zip_ref.extractall(images_dir)
+         zip_ref.extractall(root_dir)

print("\nDataset is downloaded and extracted successfully.")
    image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
    image_path_to_caption[image_path].append(caption)

- images = glob.glob("datasets/train2014/*.jpg")
+ images = glob.glob(os.path.join(images_dir, "*.jpg"))
image_paths = list(image_path_to_caption.keys())
if len(images) != len(image_paths):
-     print(
-         f"Not all images extracted correctly, expected {len(image_paths)} images, found
-         {len(images)} images"
-     )
+     print(f"Not all images extracted correctly,\n",
+           f"expected {len(image_paths)} images,\n",
+           f"found: {len(images)} images")
print(f"Number of images: {len(image_paths)}")
"""
@@ -188,7 +187,8 @@ def write_data(image_paths, num_files, files_prefix):
else:
    print(f"{num_train_files} tfrecord files found.")
    print(f"{num_train_files * images_per_file} training examples in the tfrecord files.")
- train_example_count = 60000
+ train_example_count = train_size * captions_per_image
+
found_files = glob.glob(os.path.join(root_dir, "tfrecords", "valid-*.tfrecord"))
if len(found_files) != num_valid_files:
@@ -199,7 +199,7 @@ def write_data(image_paths, num_files, files_prefix):
else:
    print(f"{num_valid_files} tfrecord files found.")
    print(f"{num_valid_files * images_per_file} validation examples in the tfrecord files.")
- valid_example_count = 10000
+ valid_example_count = valid_size * captions_per_image
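As a sanity check on the change above: with the sizes used earlier in the example (assumed here from the hardcoded literals being replaced), the computed counts reproduce the old values:

# Assumed settings; chosen so they reproduce the literals removed above.
train_size = 30000
valid_size = 5000
captions_per_image = 2

train_example_count = train_size * captions_per_image  # 30000 * 2 = 60000
valid_example_count = valid_size * captions_per_image  # 5000 * 2 = 10000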
"""
### Create a
@@ -441,25 +441,26 @@ def call(self, features, training=False):
        return caption_embeddings, image_embeddings

    def compute_loss(self, caption_embeddings, image_embeddings):
-         # logits[i][j] is the dot_similarity(caption_i, image_j).
+         # similarity between all image and caption embeddings
        logits = ops.divide(
            ops.einsum("ae,be->ab", caption_embeddings, image_embeddings),
            self.temperature,
        )

-         # images_similarity[i][j] is the dot_similarity(image_i, image_j).
+         # similarity between all image and image embeddings
        images_similarity = ops.einsum(
            "ae,be->ab", image_embeddings, image_embeddings
        )
-         # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
+
+         # similarity between all caption and caption embeddings
        captions_similarity = ops.einsum(
            "ae,be->ab", caption_embeddings, caption_embeddings
        )
-         # targets[i][j] = average dot_similarity(caption_i, caption_j) and
-         # dot_similarity(image_i, image_j).
+
        targets = keras.activations.softmax(
            (captions_similarity + images_similarity) / (2 * self.temperature)
        )
+
        # Compute the loss for the captions using cross-entropy
        captions_loss = keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=logits, from_logits=True
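To make the loss concrete, here is a minimal, self-contained sketch of the same symmetric contrastive loss on toy embeddings. It is a sketch only: the batch size, embedding width, and temperature value are made up for illustration; the keras/ops calls mirror the ones above.

import numpy as np
import keras
from keras import ops

temperature = 0.05  # assumed value for illustration
caption_embeddings = ops.convert_to_tensor(np.random.randn(4, 8).astype("float32"))
image_embeddings = ops.convert_to_tensor(np.random.randn(4, 8).astype("float32"))

# logits[i][j]: scaled similarity of caption i to image j.
logits = ops.einsum("ae,be->ab", caption_embeddings, image_embeddings) / temperature

# Self-similarities give "soft" targets rather than a plain identity matrix.
images_similarity = ops.einsum("ae,be->ab", image_embeddings, image_embeddings)
captions_similarity = ops.einsum("ae,be->ab", caption_embeddings, caption_embeddings)
targets = keras.activations.softmax(
    (captions_similarity + images_similarity) / (2 * temperature)
)

# Symmetric cross-entropy: captions -> images, and images -> captions on the
# transposed matrices; the final loss averages the two directions.
captions_loss = keras.losses.categorical_crossentropy(
    y_true=targets, y_pred=logits, from_logits=True
)
images_loss = keras.losses.categorical_crossentropy(
    y_true=ops.transpose(targets), y_pred=ops.transpose(logits), from_logits=True
)
loss = (captions_loss + images_loss) / 2
print(ops.mean(loss))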