Skip to content

Commit 495bea9

Browse files
authored
Merge pull request #119 from awslabs/xgboost-updates
Enclose the training job's get_waiter call in a try/except block
2 parents 1232b17 + 5a8f477 commit 495bea9

File tree

1 file changed

+41
-54
lines changed

1 file changed

+41
-54
lines changed

advanced_functionality/handling_kms_encrypted_data/handling_kms_encrypted_data.ipynb

Lines changed: 41 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
"cell_type": "code",
5151
"execution_count": null,
5252
"metadata": {
53-
"collapsed": true,
5453
"isConfigCell": true
5554
},
5655
"outputs": [],
@@ -69,10 +68,10 @@
6968
"assumed_role = boto3.client('sts').get_caller_identity()['Arn']\n",
7069
"role = re.sub(r'^(.+)sts::(\\d+):assumed-role/(.+?)/.*$', r'\\1iam::\\2:role/\\3', assumed_role)\n",
7170
"\n",
72-
"kms_key_id = '<your_kms_key_arn_here>'\n",
71+
"kms_key_id = '<your-kms-key-id>'\n",
7372
"\n",
74-
"bucket='<your_s3_bucket_name_here>' # put your s3 bucket name here, and create s3 bucket\n",
75-
"prefix = 'sagemarker/kms-new'\n",
73+
"bucket='<s3-bucket>' # put your s3 bucket name here, and create s3 bucket\n",
74+
"prefix = 'sagemaker/kms'\n",
7675
"# customize to your bucket where you have stored the data\n",
7776
"bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)"
7877
]
@@ -90,17 +89,6 @@
9089
"First, we read the dataset from an existing repository into memory. This processing could instead be done *in situ* by Amazon Athena, Apache Spark on Amazon EMR, Amazon Redshift, etc., assuming the dataset is present in the appropriate location. The next step is then to transfer the data to S3 for use in training. For small datasets, such as the one used below, reading into memory isn't onerous, though it would be for larger datasets."
9190
]
9291
},
93-
{
94-
"cell_type": "code",
95-
"execution_count": null,
96-
"metadata": {
97-
"collapsed": true
98-
},
99-
"outputs": [],
100-
"source": [
101-
"!conda install -y -c conda-forge scikit-learn"
102-
]
103-
},
10492
{
10593
"cell_type": "code",
10694
"execution_count": null,
@@ -189,9 +177,7 @@
189177
{
190178
"cell_type": "code",
191179
"execution_count": null,
192-
"metadata": {
193-
"collapsed": true
194-
},
180+
"metadata": {},
195181
"outputs": [],
196182
"source": [
197183
"s3 = boto3.client('s3')\n",
@@ -248,9 +234,7 @@
248234
{
249235
"cell_type": "code",
250236
"execution_count": null,
251-
"metadata": {
252-
"collapsed": true
253-
},
237+
"metadata": {},
254238
"outputs": [],
255239
"source": [
256240
"%%time\n",
@@ -299,7 +283,7 @@
299283
" \"S3DataDistributionType\": \"FullyReplicated\"\n",
300284
" }\n",
301285
" },\n",
302-
" \"ContentType\": \"libsvm\",\n",
286+
" \"ContentType\": \"csv\",\n",
303287
" \"CompressionType\": \"None\"\n",
304288
" },\n",
305289
" {\n",
@@ -320,12 +304,17 @@
320304
"client = boto3.client('sagemaker')\n",
321305
"client.create_training_job(**create_training_params)\n",
322306
"\n",
323-
"status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
324-
"print(status)\n",
325-
"while status !='Completed' and status!='Failed':\n",
326-
" time.sleep(60)\n",
327-
" status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
328-
" print(status)"
307+
"try:\n",
308+
" # wait for the job to finish and report the ending status\n",
309+
" client.get_waiter('TrainingJob_Created').wait(TrainingJobName=job_name)\n",
310+
" training_info = client.describe_training_job(TrainingJobName=job_name)\n",
311+
" status = training_info['TrainingJobStatus']\n",
312+
" print(\"Training job ended with status: \" + status)\n",
313+
"except:\n",
314+
" print('Training failed to start')\n",
315+
" # if exception is raised, that means it has failed\n",
316+
" message = client.describe_training_job(TrainingJobName=job_name)['FailureReason']\n",
317+
" print('Training failed with the following error: {}'.format(message))"
329318
]
330319
},
331320
{
@@ -343,9 +332,7 @@
343332
{
344333
"cell_type": "code",
345334
"execution_count": null,
346-
"metadata": {
347-
"collapsed": true
348-
},
335+
"metadata": {},
349336
"outputs": [],
350337
"source": [
351338
"%%time\n",
@@ -384,9 +371,7 @@
384371
{
385372
"cell_type": "code",
386373
"execution_count": null,
387-
"metadata": {
388-
"collapsed": true
389-
},
374+
"metadata": {},
390375
"outputs": [],
391376
"source": [
392377
"from time import gmtime, strftime\n",
@@ -416,9 +401,7 @@
416401
{
417402
"cell_type": "code",
418403
"execution_count": null,
419-
"metadata": {
420-
"collapsed": true
421-
},
404+
"metadata": {},
422405
"outputs": [],
423406
"source": [
424407
"%%time\n",
@@ -431,18 +414,26 @@
431414
" EndpointConfigName=endpoint_config_name)\n",
432415
"print(create_endpoint_response['EndpointArn'])\n",
433416
"\n",
434-
"resp = client.describe_endpoint(EndpointName=endpoint_name)\n",
435-
"status = resp['EndpointStatus']\n",
436-
"print(\"Status: \" + status)\n",
437417
"\n",
438-
"while status=='Creating':\n",
439-
" time.sleep(60)\n",
440-
" resp = client.describe_endpoint(EndpointName=endpoint_name)\n",
441-
" status = resp['EndpointStatus']\n",
442-
" print(\"Status: \" + status)\n",
418+
"print('EndpointArn = {}'.format(create_endpoint_response['EndpointArn']))\n",
419+
"\n",
420+
"# get the status of the endpoint\n",
421+
"response = client.describe_endpoint(EndpointName=endpoint_name)\n",
422+
"status = response['EndpointStatus']\n",
423+
"print('EndpointStatus = {}'.format(status))\n",
424+
"\n",
443425
"\n",
444-
"print(\"Arn: \" + resp['EndpointArn'])\n",
445-
"print(\"Status: \" + status)"
426+
"# wait until the status has changed\n",
427+
"client.get_waiter('Endpoint_Created').wait(EndpointName=endpoint_name)\n",
428+
"\n",
429+
"\n",
430+
"# print the status of the endpoint\n",
431+
"endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)\n",
432+
"status = endpoint_response['EndpointStatus']\n",
433+
"print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n",
434+
"\n",
435+
"if status != 'InService':\n",
436+
" raise Exception('Endpoint creation failed.')"
446437
]
447438
},
448439
{
@@ -508,9 +499,7 @@
508499
{
509500
"cell_type": "code",
510501
"execution_count": null,
511-
"metadata": {
512-
"collapsed": true
513-
},
502+
"metadata": {},
514503
"outputs": [],
515504
"source": [
516505
"%%time\n",
@@ -542,12 +531,10 @@
542531
{
543532
"cell_type": "code",
544533
"execution_count": null,
545-
"metadata": {
546-
"collapsed": true
547-
},
534+
"metadata": {},
548535
"outputs": [],
549536
"source": [
550-
"#client.delete_endpoint(EndpointName=endpoint_name)"
537+
"# client.delete_endpoint(EndpointName=endpoint_name)"
551538
]
552539
}
553540
],

0 commit comments

Comments
 (0)