Skip to content

Commit 40274d6

Browse files
authored
Merge pull request #428 from tidymodels/more-glmnet-cleanup
Handle single prediction case for multinomial glmnet
2 parents 1d21cb3 + 13cbb0f commit 40274d6

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,4 +283,5 @@ importFrom(utils,getFromNamespace)
283283
importFrom(utils,globalVariables)
284284
importFrom(utils,head)
285285
importFrom(utils,methods)
286+
importFrom(vctrs,vec_size)
286287
importFrom(vctrs,vec_unique)

R/multinom_reg.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,24 @@ check_args.multinom_reg <- function(object) {
168168

169169
# ------------------------------------------------------------------------------
170170

171+
#' @importFrom vctrs vec_size
171172
organize_multnet_class <- function(x, object) {
172-
x[,1]
173+
if (vec_size(x) > 1) {
174+
x <- x[,1]
175+
} else {
176+
x <- as.character(x)
177+
}
178+
x
173179
}
174180

181+
#' @importFrom vctrs vec_size
175182
organize_multnet_prob <- function(x, object) {
176-
x <- x[,,1]
177-
as_tibble(x)
183+
if (vec_size(x) > 1) {
184+
x <- as_tibble(x[,,1])
185+
} else {
186+
x <- tibble::as_tibble_row(x[,,1])
187+
}
188+
x
178189
}
179190

180191
organize_nnet_prob <- function(x, object) {

0 commit comments

Comments
 (0)