@@ -64,8 +64,9 @@ def do_POST(self):
64
64
65
65
# Load the files
66
66
mount_slug = f"{ model_ref ['ModelSlug' ]} /{ model_ref ['Framework' ]} /{ model_ref ['InstanceSlug' ]} /{ model_ref ['VersionNumber' ]} "
67
- os .makedirs (os .path .dirname (os .path .join (MOUNT_PATH , mount_slug )))
68
- os .symlink ('/input/tests/data/saved_model/' , os .path .join (MOUNT_PATH , mount_slug ), target_is_directory = True )
67
+ model_path = os .path .join (MOUNT_PATH , mount_slug )
68
+ os .makedirs (os .path .dirname (model_path ))
69
+ os .symlink ('/input/tests/data/saved_model/' , model_path , target_is_directory = True )
69
70
70
71
# Return the response
71
72
self .wfile .write (bytes (json .dumps ({
@@ -79,17 +80,27 @@ def do_POST(self):
79
80
self .wfile .write (bytes (f"Unhandled path: { self .path } " , "utf-8" ))
80
81
81
82
class TestKaggleModuleResolver (unittest .TestCase ):
82
- def test_kaggle_resolver_long_url_succeeds (self ):
83
+ def test_kaggle_resolver_long_url_succeeds (self ):
84
+ model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
83
85
with create_test_server (KaggleJwtHandler ) as addr :
84
86
test_inputs = tf .ones ([1 ,4 ])
85
- layer = hub .KerasLayer ("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2" )
87
+ layer = hub .KerasLayer (model_url )
86
88
self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
89
+ # Delete the files that were created in KaggleJwtHandler's do_POST method
90
+ model_path = os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )
91
+ os .unlink (model_path )
92
+ os .rmdir (os .path .dirname (model_path ))
87
93
88
- def test_kaggle_resolver_short_url_succeeds (self ):
94
+ def test_kaggle_resolver_short_url_succeeds (self ):
95
+ model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
89
96
with create_test_server (KaggleJwtHandler ) as addr :
90
97
test_inputs = tf .ones ([1 ,4 ])
91
- layer = hub .KerasLayer ("https://kaggle.com/models/bar/barmodule/pyTorch/barvar/1" )
98
+ layer = hub .KerasLayer (model_url )
92
99
self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
100
+ # Delete the files that were created in KaggleJwtHandler's do_POST method
101
+ model_path = os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )
102
+ os .unlink (model_path )
103
+ os .rmdir (os .path .dirname (model_path ))
93
104
94
105
def test_kaggle_resolver_not_attached_throws (self ):
95
106
with create_test_server (KaggleJwtHandler ) as addr :
0 commit comments