Skip to content

Commit ae42617

Browse files
authored
Merge pull request #210 from tidymodels/multi-predict-column-names
multi_predict column names
2 parents f305412 + c2036b7 commit ae42617

File tree

10 files changed

+26
-19
lines changed

10 files changed

+26
-19
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.0.3.9000
2+
Version: 0.0.3.9001
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc.).
55
Authors@R: c(

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
# parsnip 0.0.3.9000
1+
# parsnip 0.0.3.9001
22

33
* Some default parameter ranges were updated for SVM, KNN, and MARS models.
44

55
* [A bug](https://github.com/tidymodels/parsnip/issues/208) was fixed related to using data descriptors and `fit_xy()`.
66

7+
* A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a list column named `.pred`, whose elements are tibbles of predictions across sub-models. The columns of these sub-model tibbles are now named consistently with `predict()` (previously they were not). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2).
8+
79
# parsnip 0.0.3.1
810

911
Test case update due to CRAN running extra tests [(#202)](https://github.com/tidymodels/parsnip/issues/202)

R/aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ utils::globalVariables(
3333
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
3434
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
3535
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
36-
"sub_neighbors")
36+
"sub_neighbors", ".pred_class")
3737
)
3838

3939
# nocov end

R/aaa_multi_predict.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
1313
#' such as `type`.
1414
#' @return A tibble with the same number of rows as the data being predicted.
15-
#' Mostly likely, there is a list-column named `.pred` that is a tibble with
16-
#' multiple rows per sub-model.
15+
#' There is a list-column named `.pred` that contains tibbles with
16+
#' multiple rows per sub-model. Note that, within the tibbles, the column names
17+
#' follow the usual standard based on prediction `type` (i.e. `.pred_class` for
18+
#' `type = "class"` and so on).
1719
#' @export
1820
multi_predict <- function(object, ...) {
1921
if (inherits(object$fit, "try-error")) {

R/boost_tree.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
404404
} else {
405405
if (type == "class") {
406406
pred <- object$spec$method$pred$class$post(pred, object)
407-
pred <- tibble(.pred = factor(pred, levels = object$lvl))
407+
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
408408
} else {
409409
pred <- object$spec$method$pred$prob$post(pred, object)
410410
pred <- as_tibble(pred)
@@ -503,7 +503,7 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
503503

504504
# switch based on prediction type
505505
if (type == "class") {
506-
pred <- tibble(.pred = factor(pred, levels = object$lvl))
506+
pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
507507
} else {
508508
pred <- as_tibble(pred)
509509
names(pred) <- paste0(".pred_", names(pred))

R/logistic_reg.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ multi_predict._lognet <-
309309
if (is.null(type))
310310
type <- "class"
311311
if (!(type %in% c("class", "prob", "link", "raw"))) {
312-
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
312+
stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
313313
}
314314
if (type == "prob")
315315
dots$type <- "response"
@@ -321,12 +321,12 @@ multi_predict._lognet <-
321321
param_key <- tibble(group = colnames(pred), penalty = penalty)
322322
pred <- as_tibble(pred)
323323
pred$.row <- 1:nrow(pred)
324-
pred <- gather(pred, group, .pred, -.row)
324+
pred <- gather(pred, group, .pred_class, -.row)
325325
if (dots$type == "class") {
326-
pred[[".pred"]] <- factor(pred[[".pred"]], levels = object$lvl)
326+
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = object$lvl)
327327
} else {
328328
if (dots$type == "response") {
329-
pred[[".pred2"]] <- 1 - pred[[".pred"]]
329+
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
330330
names(pred) <- c(".row", "group", paste0(".pred_", rev(object$lvl)))
331331
pred <- pred[, c(".row", "group", paste0(".pred_", object$lvl))]
332332
}
@@ -371,3 +371,4 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
371371
object$spec <- eval_args(object$spec)
372372
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
373373
}
374+

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ multi_predict._multnet <-
290290
pred <-
291291
tibble(
292292
.row = rep(1:nrow(new_data), length(penalty)),
293-
.pred = as.vector(pred),
293+
.pred_class = as.vector(pred),
294294
penalty = rep(penalty, each = nrow(new_data))
295295
)
296296
}

man/multi_predict.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_logistic_reg_glmnet.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ test_that('glmnet prediction, mulitiple lambda', {
119119
mult_pred$rows <- rep(1:7, 2)
120120
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
121121
mult_pred <- mult_pred[, c("penalty", "values")]
122-
names(mult_pred) <- c("penalty", ".pred")
122+
names(mult_pred) <- c("penalty", ".pred_class")
123123
mult_pred <- tibble::as_tibble(mult_pred)
124124

125125
expect_equal(
@@ -148,7 +148,7 @@ test_that('glmnet prediction, mulitiple lambda', {
148148
form_pred$rows <- rep(1:7, 2)
149149
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
150150
form_pred <- form_pred[, c("penalty", "values")]
151-
names(form_pred) <- c("penalty", ".pred")
151+
names(form_pred) <- c("penalty", ".pred_class")
152152
form_pred <- tibble::as_tibble(form_pred)
153153

154154
expect_equal(
@@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', {
180180
mult_pred$rows <- rep(1:7, 2)
181181
mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ]
182182
mult_pred <- mult_pred[, c("penalty", "values")]
183-
names(mult_pred) <- c("penalty", ".pred")
183+
names(mult_pred) <- c("penalty", ".pred_class")
184184
mult_pred <- tibble::as_tibble(mult_pred)
185185

186186
expect_equal(mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred]) %>% unnest())
@@ -206,7 +206,7 @@ test_that('glmnet prediction, no lambda', {
206206
form_pred$rows <- rep(1:7, 2)
207207
form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ]
208208
form_pred <- form_pred[, c("penalty", "values")]
209-
names(form_pred) <- c("penalty", ".pred")
209+
names(form_pred) <- c("penalty", ".pred_class")
210210
form_pred <- tibble::as_tibble(form_pred)
211211

212212
expect_equal(

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ test_that('glmnet probabilities, mulitiple lambda', {
123123

124124
mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
125125
mult_class <- tibble(
126-
.pred = mult_class,
126+
.pred_class = mult_class,
127127
penalty = rep(lams, each = 3),
128128
row = rep(1:3, 2)
129129
)

0 commit comments

Comments
 (0)