Skip to content

Commit 360163c

Browse files
committed
Better way to handle test cleanup
1 parent fce7103 commit 360163c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

tests/test_kaggle_module_resolver.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def do_POST(self):
6464

6565
# Load the files
6666
mount_slug = f"{model_ref['ModelSlug']}/{model_ref['Framework']}/{model_ref['InstanceSlug']}/{model_ref['VersionNumber']}"
67-
os.makedirs(os.path.dirname(os.path.join(MOUNT_PATH, mount_slug)))
68-
os.symlink('/input/tests/data/saved_model/', os.path.join(MOUNT_PATH, mount_slug), target_is_directory=True)
67+
model_path = os.path.join(MOUNT_PATH, mount_slug)
68+
os.makedirs(os.path.dirname(model_path))
69+
os.symlink('/input/tests/data/saved_model/', model_path, target_is_directory=True)
6970

7071
# Return the response
7172
self.wfile.write(bytes(json.dumps({
@@ -79,17 +80,27 @@ def do_POST(self):
7980
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))
8081

8182
class TestKaggleModuleResolver(unittest.TestCase):
82-
def test_kaggle_resolver_long_url_succeeds(self):
83+
def test_kaggle_resolver_long_url_succeeds(self):
84+
model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
8385
with create_test_server(KaggleJwtHandler) as addr:
8486
test_inputs = tf.ones([1,4])
85-
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
87+
layer = hub.KerasLayer(model_url)
8688
self.assertEqual([1, 1], layer(test_inputs).shape)
89+
# Delete the files that were created in KaggleJwtHandler's do_POST method
90+
model_path = os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")
91+
os.unlink(model_path)
92+
os.rmdir(os.path.dirname(model_path))
8793

88-
def test_kaggle_resolver_short_url_succeeds(self):
94+
def test_kaggle_resolver_short_url_succeeds(self):
95+
model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
8996
with create_test_server(KaggleJwtHandler) as addr:
9097
test_inputs = tf.ones([1,4])
91-
layer = hub.KerasLayer("https://kaggle.com/models/bar/barmodule/pyTorch/barvar/1")
98+
layer = hub.KerasLayer(model_url)
9299
self.assertEqual([1, 1], layer(test_inputs).shape)
100+
# Delete the files that were created in KaggleJwtHandler's do_POST method
101+
model_path = os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")
102+
os.unlink(model_path)
103+
os.rmdir(os.path.dirname(model_path))
93104

94105
def test_kaggle_resolver_not_attached_throws(self):
95106
with create_test_server(KaggleJwtHandler) as addr:

0 commit comments

Comments
 (0)