@@ -141,54 +141,86 @@ def create_pre_transform(size):
141141 elif opt .model == 3 :
142142 if opt .mode == 0 :
143143 model = torch .nn .Sequential (torch .nn .Linear (int (1024 * dim ), 512 ),
144+ torch .nn .ReLU ()
144145 torch .nn .Linear (512 , 512 ),
146+ torch .nn .ReLU ()
145147 torch .nn .Linear (512 , config ['model' ]['num-classes' ]))
146148 else :
147149 model = torch .nn .Sequential (torch .nn .Linear (int (2048 * dim ), 512 ),
150+ torch .nn .ReLU ()
148151 torch .nn .Linear (512 , 512 ),
152+ torch .nn .ReLU ()
149153 torch .nn .Linear (512 , config ['model' ]['num-classes' ]))
150154 elif opt .model == 4 :
151155 if opt .mode == 0 :
152156 model = torch .nn .Sequential (torch .nn .Linear (int (1024 * dim ), 512 ),
157+ torch .nn .ReLU ()
153158 torch .nn .Linear (512 , 512 ),
159+ torch .nn .ReLU ()
154160 torch .nn .Linear (512 , 256 ),
161+ torch .nn .ReLU ()
155162 torch .nn .Linear (256 , 256 ),
163+ torch .nn .ReLU ()
156164 torch .nn .Linear (256 , config ['model' ]['num-classes' ]))
157165 else :
158166 model = torch .nn .Sequential (torch .nn .Linear (int (2048 * dim ), 512 ),
167+ torch .nn .ReLU ()
159168 torch .nn .Linear (512 , 512 ),
169+ torch .nn .ReLU ()
160170 torch .nn .Linear (512 , 256 ),
171+ torch .nn .ReLU ()
161172 torch .nn .Linear (256 , 256 ),
173+ torch .nn .ReLU ()
162174 torch .nn .Linear (256 , config ['model' ]['num-classes' ]))
163175 elif opt .model == 5 :
164176 if opt .mode == 0 :
165177 model = torch .nn .Sequential (torch .nn .Linear (int (1024 * dim ), 4096 ),
178+ torch .nn .ReLU ()
166179 torch .nn .Linear (4096 , 4096 ),
180+ torch .nn .ReLU ()
167181 torch .nn .Linear (4096 , 1024 ),
182+ torch .nn .ReLU ()
168183 torch .nn .Linear (1024 , config ['model' ]['num-classes' ]))
169184 else :
170185 model = torch .nn .Sequential (torch .nn .Linear (int (2048 * dim ), 4096 ),
186+ torch .nn .ReLU ()
171187 torch .nn .Linear (4096 , 4096 ),
188+ torch .nn .ReLU ()
172189 torch .nn .Linear (4096 , 1024 ),
190+ torch .nn .ReLU ()
173191 torch .nn .Linear (1024 , config ['model' ]['num-classes' ]))
174192 elif opt .model == 6 :
175193 if opt .mode == 0 :
176194 model = torch .nn .Sequential (torch .nn .Linear (int (1024 * dim ), 8192 ),
195+ torch .nn .ReLU ()
177196 torch .nn .Linear (8192 , 4096 ),
197+ torch .nn .ReLU ()
178198 torch .nn .Linear (4096 , 4096 ),
199+ torch .nn .ReLU ()
179200 torch .nn .Linear (4096 , 2048 ),
201+ torch .nn .ReLU ()
180202 torch .nn .Linear (2048 , 2048 ),
203+ torch .nn .ReLU ()
181204 torch .nn .Linear (2048 , 1024 ),
205+ torch .nn .ReLU ()
182206 torch .nn .Linear (1024 , 1024 ),
207+ torch .nn .ReLU ()
183208 torch .nn .Linear (1024 , config ['model' ]['num-classes' ]))
184209 else :
185210 model = torch .nn .Sequential (torch .nn .Linear (int (2048 * dim ), 8192 ),
211+ torch .nn .ReLU ()
186212 torch .nn .Linear (8192 , 4096 ),
213+ torch .nn .ReLU ()
187214 torch .nn .Linear (4096 , 4096 ),
215+ torch .nn .ReLU ()
188216 torch .nn .Linear (4096 , 2048 ),
217+ torch .nn .ReLU ()
189218 torch .nn .Linear (2048 , 2048 ),
219+ torch .nn .ReLU ()
190220 torch .nn .Linear (2048 , 1024 ),
221+ torch .nn .ReLU ()
191222 torch .nn .Linear (1024 , 1024 ),
223+ torch .nn .ReLU ()
192224 torch .nn .Linear (1024 , config ['model' ]['num-classes' ]))
193225 elif opt .model == 7 :
194226 model = timm .create_model ('xception' , pretrained = True )
0 commit comments