Skip to content

Commit 980e3c9

Browse files
Upload last models version
1 parent e1426cc commit 980e3c9

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

train.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,54 +141,86 @@ def create_pre_transform(size):
141141
elif opt.model == 3:
142142
if opt.mode == 0:
143143
model = torch.nn.Sequential(torch.nn.Linear(int(1024*dim), 512),
144+
torch.nn.ReLU()
144145
torch.nn.Linear(512, 512),
146+
torch.nn.ReLU()
145147
torch.nn.Linear(512, config['model']['num-classes']))
146148
else:
147149
model = torch.nn.Sequential(torch.nn.Linear(int(2048*dim), 512),
150+
torch.nn.ReLU()
148151
torch.nn.Linear(512, 512),
152+
torch.nn.ReLU()
149153
torch.nn.Linear(512, config['model']['num-classes']))
150154
elif opt.model == 4:
151155
if opt.mode == 0:
152156
model = torch.nn.Sequential(torch.nn.Linear(int(1024*dim), 512),
157+
torch.nn.ReLU()
153158
torch.nn.Linear(512, 512),
159+
torch.nn.ReLU()
154160
torch.nn.Linear(512, 256),
161+
torch.nn.ReLU()
155162
torch.nn.Linear(256, 256),
163+
torch.nn.ReLU()
156164
torch.nn.Linear(256, config['model']['num-classes']))
157165
else:
158166
model = torch.nn.Sequential(torch.nn.Linear(int(2048*dim), 512),
167+
torch.nn.ReLU()
159168
torch.nn.Linear(512, 512),
169+
torch.nn.ReLU()
160170
torch.nn.Linear(512, 256),
171+
torch.nn.ReLU()
161172
torch.nn.Linear(256, 256),
173+
torch.nn.ReLU()
162174
torch.nn.Linear(256, config['model']['num-classes']))
163175
elif opt.model == 5:
164176
if opt.mode == 0:
165177
model = torch.nn.Sequential(torch.nn.Linear(int(1024*dim), 4096),
178+
torch.nn.ReLU()
166179
torch.nn.Linear(4096, 4096),
180+
torch.nn.ReLU()
167181
torch.nn.Linear(4096, 1024),
182+
torch.nn.ReLU()
168183
torch.nn.Linear(1024, config['model']['num-classes']))
169184
else:
170185
model = torch.nn.Sequential(torch.nn.Linear(int(2048*dim), 4096),
186+
torch.nn.ReLU()
171187
torch.nn.Linear(4096, 4096),
188+
torch.nn.ReLU()
172189
torch.nn.Linear(4096, 1024),
190+
torch.nn.ReLU()
173191
torch.nn.Linear(1024, config['model']['num-classes']))
174192
elif opt.model == 6:
175193
if opt.mode == 0:
176194
model = torch.nn.Sequential(torch.nn.Linear(int(1024*dim), 8192),
195+
torch.nn.ReLU()
177196
torch.nn.Linear(8192, 4096),
197+
torch.nn.ReLU()
178198
torch.nn.Linear(4096, 4096),
199+
torch.nn.ReLU()
179200
torch.nn.Linear(4096, 2048),
201+
torch.nn.ReLU()
180202
torch.nn.Linear(2048, 2048),
203+
torch.nn.ReLU()
181204
torch.nn.Linear(2048, 1024),
205+
torch.nn.ReLU()
182206
torch.nn.Linear(1024, 1024),
207+
torch.nn.ReLU()
183208
torch.nn.Linear(1024, config['model']['num-classes']))
184209
else:
185210
model = torch.nn.Sequential(torch.nn.Linear(int(2048*dim), 8192),
211+
torch.nn.ReLU()
186212
torch.nn.Linear(8192, 4096),
213+
torch.nn.ReLU()
187214
torch.nn.Linear(4096, 4096),
215+
torch.nn.ReLU()
188216
torch.nn.Linear(4096, 2048),
217+
torch.nn.ReLU()
189218
torch.nn.Linear(2048, 2048),
219+
torch.nn.ReLU()
190220
torch.nn.Linear(2048, 1024),
221+
torch.nn.ReLU()
191222
torch.nn.Linear(1024, 1024),
223+
torch.nn.ReLU()
192224
torch.nn.Linear(1024, config['model']['num-classes']))
193225
elif opt.model == 7:
194226
model = timm.create_model('xception', pretrained=True)

0 commit comments

Comments
 (0)