Skip to content

Commit 6722690

Browse files
authored
Port "R bring your own" notebook to R (#1221)
* feat(rbyo): first draft convert nb to R Convert Python notebook in this example to R
* feat(rbyo): Surface console output on failure Print outputs to notebook even when shell command fails, with new utility function nbsystem2. Also improve representation of elapsed times with repr.
* refactor(rbyo): simplify S3 data upload Use SM session instead of boto3 s3 resource
* doc(rbyo): syntax highlight R code in NB markdown
* fix(rbyo): syntax error missing close brackets Introduced with repr.
* fix(rbyo): repeatably runnable df_to_csv function Wrap the textConnection in a function so predict cell can be run multiple times
* doc(rbyo): update commentary in line with R nb Update comments to reflect that notebook is now in R
* doc(rbyo): remove out-of-scope Docker TODOs Remove extra TODOs and update NB to match Dockerfile
* style(rbyo): R code style and extra doc comments Fix the bizarre brace spacing used in the R scripts; replicate this code to the notebook sections that duplicate script code; add more guidance for SMStudio users (not ready for full support yet); and expand notes on further reading e.g. batch transform, HPO.
1 parent c51a1e0 commit 6722690

File tree

6 files changed

+375
-206
lines changed

6 files changed

+375
-206
lines changed

advanced_functionality/r_bring_your_own/Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ FROM ubuntu:20.04
22

33
ARG DEBIAN_FRONTEND=noninteractive
44

5+
# Don't prompt for tzdata on new versions of Ubuntu:
6+
ARG DEBIAN_FRONTEND=noninteractive
7+
58
RUN apt-get -y update && apt-get install -y --no-install-recommends \
69
wget \
710
libcurl4-openssl-dev\
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/bin/bash
2+
3+
# The name of our algorithm
4+
algorithm_name=sagemaker-rmars
5+
6+
set -e # stop if anything fails
7+
8+
account=$(aws sts get-caller-identity --query Account --output text)
9+
echo "AWS Account ID $account"
10+
11+
# Get the region defined in the current configuration (default to us-west-2 if none defined)
12+
region=$(aws configure get region)
13+
region=${region:-us-west-2}
14+
echo "AWS Region $region"
15+
16+
fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:latest"
17+
18+
echo "Target image URI $fullname"
19+
20+
# If the repository doesn't exist in ECR, create it.
21+
22+
echo "Checking for existing repository..."
23+
set +e
24+
aws ecr describe-repositories --repository-names "${algorithm_name}"
25+
if [ $? -ne 0 ]
26+
then
27+
set -e
28+
echo "Creating repository"
29+
aws ecr create-repository --repository-name "${algorithm_name}"
30+
else
31+
set -e
32+
fi
33+
34+
# Log Docker in to ECR: `aws ecr get-login` was removed in AWS CLI v2, so pipe a password from `get-login-password` into `docker login` via stdin
35+
aws ecr get-login-password --region "${region}" | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"
36+
37+
# Build the docker image locally with the image name and then push it to ECR
38+
# with the full name.
39+
docker build -t ${algorithm_name} .
40+
docker tag ${algorithm_name} ${fullname}
41+
42+
docker push ${fullname}

advanced_functionality/r_bring_your_own/mars.R

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ training_path <- paste(input_path, channel_name, sep='/')
3232

3333
# Setup training function
3434
train <- function() {
35-
3635
# Read in hyperparameters
3736
training_params <- read_json(param_path)
3837

3938
target <- training_params$target
4039

4140
if (!is.null(training_params$degree)) {
42-
degree <- as.numeric(training_params$degree)}
43-
else {
44-
degree <- 2}
41+
degree <- as.numeric(training_params$degree)
42+
} else {
43+
degree <- 2
44+
}
4545

4646
# Bring in data
4747
training_files = list.files(path=training_path, full.names=TRUE)
@@ -51,8 +51,10 @@ train <- function() {
5151
training_X <- model.matrix(~., training_data[, colnames(training_data) != target])
5252

5353
# Save factor levels for scoring
54-
factor_levels <- lapply(training_data[, sapply(training_data, is.factor), drop=FALSE],
55-
function(x) {levels(x)})
54+
factor_levels <- lapply(
55+
training_data[, sapply(training_data, is.factor), drop=FALSE],
56+
function(x) { levels(x) }
57+
)
5658

5759
# Run multivariate adaptive regression splines algorithm
5860
model <- mars(x=training_X, y=training_data[, target], degree=degree)
@@ -64,18 +66,22 @@ train <- function() {
6466
print(summary(mars_model))
6567

6668
write.csv(model$fitted.values, paste(output_path, 'data/fitted_values.csv', sep='/'), row.names=FALSE)
67-
write('success', file=paste(output_path, 'success', sep='/'))}
69+
write('success', file=paste(output_path, 'success', sep='/'))
70+
}
6871

6972

7073
# Setup scoring function
7174
serve <- function() {
7275
app <- plumb(paste(prefix, 'plumber.R', sep='/'))
73-
app$run(host='0.0.0.0', port=8080)}
76+
app$run(host='0.0.0.0', port=8080)
77+
}
7478

7579

7680
# Run at start-up
7781
args <- commandArgs()
7882
if (any(grepl('train', args))) {
79-
train()}
83+
train()
84+
}
8085
if (any(grepl('serve', args))) {
81-
serve()}
86+
serve()
87+
}

advanced_functionality/r_bring_your_own/plumber.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
#' Ping to show server is there
1212
#' @get /ping
1313
function() {
14-
return('')}
14+
return('')
15+
}
1516

1617

1718
#' Parse input and return prediction from model
1819
#' @param req The http request sent
1920
#' @post /invocations
2021
function(req) {
21-
2222
# Setup locations
2323
prefix <- '/opt/ml'
2424
model_path <- paste(prefix, 'model', sep='/')
@@ -35,4 +35,5 @@ function(req) {
3535
scoring_X <- model.matrix(~., data, xlev=factor_levels)
3636

3737
# Return prediction
38-
return(paste(predict(mars_model, scoring_X, row.names=FALSE), collapse=','))}
38+
return(paste(predict(mars_model, scoring_X, row.names=FALSE), collapse=','))
39+
}

0 commit comments

Comments
 (0)