1
1
#!/usr/bin/python
2
2
#
3
- # Copyright 2023 Kaggle Inc
3
+ # Copyright 2024 Kaggle Inc
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
@@ -271,7 +271,7 @@ def __repr__(self):
271
271
272
272
273
273
class KaggleApi (KaggleApi ):
274
- __version__ = '1.5.16 '
274
+ __version__ = '1.6.1 '
275
275
276
276
CONFIG_NAME_PROXY = 'proxy'
277
277
CONFIG_NAME_COMPETITION = 'competition'
@@ -909,7 +909,7 @@ def competition_download_file(self,
909
909
""" download a competition file to a designated location, or use
910
910
a default location
911
911
912
- Paramters
912
+ Parameters
913
913
=========
914
914
competition: the name of the competition
915
915
file_name: the configuration file name
@@ -2880,6 +2880,16 @@ def model_instance_get_cli(self, model_instance, folder=None):
2880
2880
data ['licenseName' ] = mi ['licenseName' ]
2881
2881
data ['fineTunable' ] = mi ['fineTunable' ]
2882
2882
data ['trainingData' ] = mi ['trainingData' ]
2883
+ data ['versionId' ] = mi ['versionId' ]
2884
+ data ['versionNumber' ] = mi ['versionNumber' ]
2885
+ data ['modelInstanceType' ] = mi ['modelInstanceType' ]
2886
+ if mi ['baseModelInstanceInformation' ] is not None :
2887
+ data ['baseModelInstance' ] = '{}/{}/{}/{}' .format (
2888
+ mi ['baseModelInstanceInformation' ]['owner' ]['slug' ],
2889
+ mi ['baseModelInstanceInformation' ]['modelSlug' ],
2890
+ mi ['baseModelInstanceInformation' ]['framework' ],
2891
+ mi ['baseModelInstanceInformation' ]['instanceSlug' ])
2892
+ data ['externalBaseModelUrl' ] = mi ['externalBaseModelUrl' ]
2883
2893
2884
2894
with open (meta_file , 'w' ) as f :
2885
2895
json .dump (data , f , indent = 2 )
@@ -2924,7 +2934,13 @@ def model_instance_initialize(self, folder):
2924
2934
'Apache 2.0' ,
2925
2935
'fineTunable' :
2926
2936
False ,
2927
- 'trainingData' : []
2937
+ 'trainingData' : [],
2938
+ 'modelInstanceType' :
2939
+ 'Unspecified' ,
2940
+ 'baseModelInstanceId' :
2941
+ 0 ,
2942
+ 'externalBaseModelUrl' :
2943
+ ''
2928
2944
}
2929
2945
meta_file = os .path .join (folder , self .MODEL_INSTANCE_METADATA_FILE )
2930
2946
with open (meta_file , 'w' ) as f :
@@ -2964,6 +2980,12 @@ def model_instance_create(self, folder, quiet=False, dir_mode='skip'):
2964
2980
license_name = self .get_or_fail (meta_data , 'licenseName' )
2965
2981
fine_tunable = self .get_or_default (meta_data , 'fineTunable' , False )
2966
2982
training_data = self .get_or_default (meta_data , 'trainingData' , [])
2983
+ model_instance_type = self .get_or_default (
2984
+ meta_data , 'modelInstanceType' , 'Unspecified' )
2985
+ base_model_instance = self .get_or_default (meta_data ,
2986
+ 'baseModelInstance' , '' )
2987
+ external_base_model_url = self .get_or_default (
2988
+ meta_data , 'externalBaseModelUrl' , '' )
2967
2989
2968
2990
# validations
2969
2991
if owner_slug == 'INSERT_OWNER_SLUG_HERE' :
@@ -2997,6 +3019,9 @@ def model_instance_create(self, folder, quiet=False, dir_mode='skip'):
2997
3019
license_name = license_name ,
2998
3020
fine_tunable = fine_tunable ,
2999
3021
training_data = training_data ,
3022
+ model_instance_type = model_instance_type ,
3023
+ base_model_instance = base_model_instance ,
3024
+ external_base_model_url = external_base_model_url ,
3000
3025
files = [])
3001
3026
3002
3027
with ResumableUploadContext () as upload_context :
@@ -3089,6 +3114,12 @@ def model_instance_update(self, folder):
3089
3114
license_name = self .get_or_default (meta_data , 'licenseName' , None )
3090
3115
fine_tunable = self .get_or_default (meta_data , 'fineTunable' , None )
3091
3116
training_data = self .get_or_default (meta_data , 'trainingData' , None )
3117
+ model_instance_type = self .get_or_default (meta_data ,
3118
+ 'modelInstanceType' , None )
3119
+ base_model_instance = self .get_or_default (meta_data ,
3120
+ 'baseModelInstance' , None )
3121
+ external_base_model_url = self .get_or_default (
3122
+ meta_data , 'externalBaseModelUrl' , None )
3092
3123
3093
3124
# validations
3094
3125
if owner_slug == 'INSERT_OWNER_SLUG_HERE' :
@@ -3128,13 +3159,22 @@ def model_instance_update(self, folder):
3128
3159
update_mask ['paths' ].append ('fine_tunable' )
3129
3160
if training_data != None :
3130
3161
update_mask ['paths' ].append ('training_data' )
3162
+ if model_instance_type != None :
3163
+ update_mask ['paths' ].append ('model_instance_type' )
3164
+ if base_model_instance != None :
3165
+ update_mask ['paths' ].append ('base_model_instance' )
3166
+ if external_base_model_url != None :
3167
+ update_mask ['paths' ].append ('external_base_model_url' )
3131
3168
3132
3169
request = ModelInstanceUpdateRequest (
3133
3170
overview = overview ,
3134
3171
usage = usage ,
3135
3172
license_name = license_name ,
3136
3173
fine_tunable = fine_tunable ,
3137
3174
training_data = training_data ,
3175
+ model_instance_type = model_instance_type ,
3176
+ base_model_instance = base_model_instance ,
3177
+ external_base_model_url = external_base_model_url ,
3138
3178
update_mask = update_mask )
3139
3179
result = ModelNewResponse (
3140
3180
self .process_response (
@@ -3283,6 +3323,7 @@ def model_instance_version_download(self,
3283
3323
os .remove (outfile )
3284
3324
except OSError as e :
3285
3325
print ('Could not delete tar file, got %s' % e )
3326
+ return outfile
3286
3327
3287
3328
def model_instance_version_download_cli (self ,
3288
3329
model_instance_version ,
@@ -3301,7 +3342,7 @@ def model_instance_version_download_cli(self,
3301
3342
quiet: suppress verbose output (default is False)
3302
3343
untar: if True, untar files upon download (default is False)
3303
3344
"""
3304
- self .model_instance_version_download (
3345
+ return self .model_instance_version_download (
3305
3346
model_instance_version ,
3306
3347
path = path ,
3307
3348
untar = untar ,
@@ -3528,6 +3569,9 @@ def process_response(self, result):
3528
3569
'Version, please consider updating (server ' +
3529
3570
api_version + ' / client ' + self .__version__ + ')' )
3530
3571
self .already_printed_version_warning = True
3572
+ if isinstance (data ,
3573
+ dict ) and 'code' in data and data ['code' ] != 200 :
3574
+ raise Exception (data ['message' ])
3531
3575
return data
3532
3576
return result
3533
3577
0 commit comments