1+ import os
2+ import tarfile
13import zipfile
24from pathlib import Path
35
46import numpy as np
57from tqdm import tqdm
68
7- from mnist_data_downloader import download_data
9+ from mnist_data_downloader import download_mnist_data
10+ from multi30k_data_downloader import download_multi30k_data
811
912
1013def prepare_mnist_data (data ):
@@ -24,39 +27,39 @@ def load_mnist(path="datasets/mnist/"):
2427 train_url = "https://pjreddie.com/media/files/mnist_train.csv"
2528 test_url = "https://pjreddie.com/media/files/mnist_test.csv"
2629
27- download_data (train_url , path + "mnist_train.csv" )
28- download_data (test_url , path + "mnist_test.csv" )
30+ download_mnist_data (train_url , path + "mnist_train.csv" )
31+ download_mnist_data (test_url , path + "mnist_test.csv" )
2932
30- training_data = Path (path ).joinpath ("mnist_train.csv" ).open ("r" ).readlines ()
33+ train_data = Path (path ).joinpath ("mnist_train.csv" ).open ("r" ).readlines ()
3134 test_data = Path (path ).joinpath ("mnist_test.csv" ).open ("r" ).readlines ()
3235
3336
3437 if not (Path (path ) / "mnist_train.npy" ).exists () or not (Path (path ) / "mnist_test.npy" ).exists ():
35- training_inputs , training_targets = prepare_mnist_data (training_data )
36- training_inputs = np .asfarray (training_inputs )
38+ train_inputs , train_targets = prepare_mnist_data (train_data )
39+ train_inputs = np .asfarray (train_inputs )
3740
3841 test_inputs , test_targets = prepare_mnist_data (test_data )
3942 test_inputs = np .asfarray (test_inputs )
4043
41- np .save (path + "mnist_train.npy" , training_inputs )
44+ np .save (path + "mnist_train.npy" , train_inputs )
4245 np .save (path + "mnist_test.npy" , test_inputs )
4346
44- np .save (path + "mnist_train_targets.npy" , training_targets )
47+ np .save (path + "mnist_train_targets.npy" , train_targets )
4548 np .save (path + "mnist_test_targets.npy" , test_targets )
4649 else :
47- training_inputs = np .load (path + "mnist_train.npy" )
50+ train_inputs = np .load (path + "mnist_train.npy" )
4851 test_inputs = np .load (path + "mnist_test.npy" )
4952
50- training_targets = np .load (path + "mnist_train_targets.npy" )
53+ train_targets = np .load (path + "mnist_train_targets.npy" )
5154 test_targets = np .load (path + "mnist_test_targets.npy" )
5255
53- training_dataset = training_inputs
56+ train_dataset = train_inputs
5457 test_dataset = test_inputs
5558
56- return training_dataset , test_dataset , training_targets , test_targets
59+ return train_dataset , test_dataset , train_targets , test_targets
60+
5761
5862
59- import os
6063
6164
6265def prepare_utkface_data (path , image_size = (3 , 32 , 32 )):
@@ -69,16 +72,16 @@ def prepare_utkface_data(path, image_size = (3, 32, 32)):
6972 images = os .listdir (path )
7073 random .shuffle (images )
7174
72- training_inputs = []
75+ train_inputs = []
7376 for image in tqdm (images , desc = 'preparing data' ):
7477 image = Image .open (path + "/" + image )
7578 image = image .resize ((image_size [1 ], image_size [2 ]))
7679 image = np .asarray (image )
7780 image = image .transpose (2 , 0 , 1 )
7881 image = image / 127.5 - 1
79- training_inputs .append (image )
82+ train_inputs .append (image )
8083
81- return np .array (training_inputs )
84+ return np .array (train_inputs )
8285
8386
8487def load_utkface (path = "datasets/utkface/" , image_size = (3 , 32 , 32 )):
@@ -92,9 +95,70 @@ def load_utkface(path="datasets/utkface/", image_size=(3, 32, 32)):
9295
9396 save_path = path / 'UTKFace.npy'
9497 if not save_path .exists ():
95- training_inputs = prepare_utkface_data (path / 'UTKFace' , image_size )
96- np .save (save_path , training_inputs )
98+ train_inputs = prepare_utkface_data (path / 'UTKFace' , image_size )
99+ np .save (save_path , train_inputs )
97100 else :
98- training_inputs = np .load (save_path )
101+ train_inputs = np .load (save_path )
102+
103+ return train_inputs
104+
105+
106+
def load_multi30k(path="datasets/multi30k/"):
    """Download, extract and load the Multi30k English/German sentence pairs.

    The three archives are fetched into ``path`` via
    ``download_multi30k_data``, extracted in place, and the resulting
    ``train``/``val``/``test`` files with ``.en`` and ``.de`` suffixes are
    read line by line.

    Args:
        path: Directory that receives the archives and the extracted text
            files. It is created (including parents) when missing.

    Returns:
        A ``(train_dataset, valid_dataset, test_dataset)`` tuple. Each
        dataset is a list of ``{'en': ..., 'de': ...}`` dicts holding one
        whitespace-stripped sentence pair; pairs where either side is empty
        are dropped.

    Raises:
        ValueError: If the ``.en`` and ``.de`` files of a split do not have
            the same number of lines.
    """
    # Reference: https://pytorch.org/text/stable/_modules/torchtext/datasets/multi30k.html
    urls = {
        "train": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz",
        "valid": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz",
        "test": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/mmt16_task1_test.tar.gz",
    }

    # Derive each archive name from its URL so the i-th filename always
    # belongs to the i-th URL. The previous hand-written list was ordered
    # differently from the URLs, which mislabels the archives if the
    # downloader pairs the two sequences positionally.
    # NOTE(review): confirm how download_multi30k_data pairs urls/filenames.
    filenames = [url.rsplit("/", 1)[-1] for url in urls.values()]

    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    download_multi30k_data(urls.values(), path, filenames)

    for filename in filenames:
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() trusts member paths inside a downloaded
        # archive; consider passing filter="data" once the minimum supported
        # Python version provides tarfile extraction filters.
        with tarfile.open(path / filename) as tar:
            tar.extractall(path)

        print(f'Extracted {filename}')

    datasets = []
    for split in ("train", "val", "test"):
        en_lines = _read_stripped_lines(path / (split + '.en'))
        de_lines = _read_stripped_lines(path / (split + '.de'))

        # Explicit check instead of `assert`, which is removed under -O.
        if len(en_lines) != len(de_lines):
            raise ValueError(
                f"Multi30k split '{split}' is misaligned: "
                f"{len(en_lines)} English lines vs {len(de_lines)} German lines"
            )

        # Keep only pairs where both sentences are non-empty.
        datasets.append([
            {'en': en_seq, 'de': de_seq}
            for en_seq, de_seq in zip(en_lines, de_lines)
            if en_seq != '' and de_seq != ''
        ])

    train_dataset, valid_dataset, test_dataset = datasets
    return train_dataset, valid_dataset, test_dataset


def _read_stripped_lines(file_path):
    """Return the whitespace-stripped lines of a UTF-8 text file."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [line.strip() for line in handle]
157+
158+
159+
160+
161+
162+
99163
100- return training_inputs
164+
0 commit comments