Skip to content

Commit 32dbb32

Browse files
committed
Handle single prediction case for multinomial glmnet
1 parent 1d21cb3 commit 32dbb32

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,4 +283,5 @@ importFrom(utils,getFromNamespace)
283283
importFrom(utils,globalVariables)
284284
importFrom(utils,head)
285285
importFrom(utils,methods)
286+
importFrom(vctrs,vec_size)
286287
importFrom(vctrs,vec_unique)

R/multinom_reg.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,25 @@ check_args.multinom_reg <- function(object) {
168168

169169
# ------------------------------------------------------------------------------
170170

171+
#' @importFrom vctrs vec_size
171172
organize_multnet_class <- function(x, object) {
172-
x[,1]
173+
if (vec_size(x) > 1) {
174+
x <- x[,1]
175+
} else {
176+
x <- as.character(x)
177+
}
178+
x
173179
}
174180

181+
#' @importFrom vctrs vec_size
175182
organize_multnet_prob <- function(x, object) {
176183
x <- x[,,1]
177-
as_tibble(x)
184+
if (vec_size(x) > 1) {
185+
x <- as_tibble(x)
186+
} else {
187+
x <- tibble::as_tibble_row(x)
188+
}
189+
x
178190
}
179191

180192
organize_nnet_prob <- function(x, object) {

0 commit comments

Comments
 (0)