Skip to content

Commit 823af9a

Browse files
authored
Fix UserSecretsClient#set_tensorflow_credentials (#1333)
http://b/313994895
1 parent 8c70958 commit 823af9a

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

patches/kaggle_secrets.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,11 @@ def set_gcloud_credentials(self, project=None, account=None):
106106
subprocess.run(['gcloud', 'config', 'set', 'account', account])
107107

108108
def set_tensorflow_credential(self, credential):
109-
"""Sets the credential for use by Tensorflow both in the local notebook
110-
and to pass to the TPU.
111-
"""
112-
# b/159906185: Import tensorflow_gcs_config only when this method is called to prevent preloading TensorFlow.
113-
import tensorflow_gcs_config
109+
"""Sets the credential for use by Tensorflow"""
114110

115-
# Write to a local JSON credentials file and set
116-
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
111+
# Write to a local JSON credentials file
117112
self._write_credentials_file(credential)
118113

119-
# set the credential for the TPU
120-
tensorflow_gcs_config.configure_gcs(credentials=credential)
121-
122114
def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
123115
"""Retrieves BigQuery access token information from the UserSecrets service.
124116

tests/test_user_secrets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,22 @@ def test_fn():
166166

167167
self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)
168168

169+
def test_set_tensorflow_credential(self):
170+
secret = '{"client_id":"gcloud","type":"authorized_user","refresh_token":"refresh_token"}'
171+
172+
def test_fn():
173+
client = UserSecretsClient()
174+
creds = client.get_gcloud_credential()
175+
client.set_tensorflow_credential(creds)
176+
177+
expected_creds_file = '/tmp/gcloud_credential.json'
178+
self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
179+
180+
with open(expected_creds_file, 'r') as f:
181+
self.assertEqual(secret, '\n'.join(f.readlines()))
182+
183+
self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)
184+
169185
@mock.patch('kaggle_secrets.datetime')
170186
def test_get_access_token_succeeds(self, mock_dt):
171187
secret = '12345'

0 commit comments

Comments
 (0)