Skip to content

Commit 4b478cf

Browse files
committed
closes #275
1 parent 421481c commit 4b478cf

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

R/aaa_multi_predict.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,18 @@ multi_predict_args.default <- function(object, ...) {
117117
#' @export
118118
#' @rdname has_multi_predict
119119
multi_predict_args.model_fit <- function(object, ...) {
120-
existing_mthds <- methods("multi_predict")
121-
cls <- class(object)
122-
tst <- paste0("multi_predict.", cls)
123-
.fn <- tst[tst %in% existing_mthds]
124-
if (length(.fn) == 0) {
125-
return(NA_character_)
120+
model_type <- class(object$spec)[1]
121+
arg_info <- get_from_env(paste0(model_type, "_args"))
122+
arg_info <- arg_info[arg_info$engine == object$spec$engine,]
123+
arg_info <- arg_info[arg_info$has_submodel,]
124+
125+
if (nrow(arg_info) == 0) {
126+
res <- NA_character_
127+
} else {
128+
res <- arg_info[["parsnip"]]
126129
}
127130

128-
.fn <- getFromNamespace(.fn, ns = "parsnip")
129-
omit <- c('object', 'new_data', 'type', '...')
130-
args <- names(formals(.fn))
131-
args[!(args %in% omit)]
131+
res
132132
}
133133

134134
#' @export

R/linear_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ set_model_arg(
293293
parsnip = "penalty",
294294
original = "reg_param",
295295
func = list(pkg = "dials", fun = "penalty"),
296-
has_submodel = TRUE
296+
has_submodel = FALSE
297297
)
298298

299299
set_model_arg(

0 commit comments

Comments
 (0)