Skip to content

Commit 218cc73

Browse files
committed
proper uniqueness enforcement for argument tibbles
1 parent 1b76e44 commit 218cc73

File tree

4 files changed

+60
-53
lines changed

4 files changed

+60
-53
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ Imports:
2727
magrittr,
2828
stats,
2929
tidyr,
30-
globals
30+
globals,
31+
vctrs
3132
Roxygen: list(markdown = TRUE)
3233
RoxygenNote: 6.1.1
3334
Suggests:

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,4 @@ importFrom(utils,capture.output)
210210
importFrom(utils,getFromNamespace)
211211
importFrom(utils,globalVariables)
212212
importFrom(utils,head)
213+
importFrom(vctrs,vec_unique)

R/aaa_models.R

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,8 @@ set_in_env <- function(...) {
7979
#' @keywords internal
8080
#' @export
8181
set_env_val <- function(name, value) {
82-
if (length(name) != 1 | length(value) != 1) {
83-
stop("`name` and `value` should both be a single value.", call. = FALSE)
84-
}
85-
if (!is.character(name)) {
86-
stop("`name` should be a character value.", call. = FALSE)
82+
if (length(name) != 1 || !is.character(name)) {
83+
stop("`name` should be a single character value.", call. = FALSE)
8784
}
8885
mod_env <- get_model_env()
8986
x <- list(value)
@@ -329,31 +326,40 @@ set_new_model <- function(model) {
329326

330327
current <- get_model_env()
331328

332-
current$models <- c(current$models, model)
333-
current[[model]] <- dplyr::tibble(engine = character(0), mode = character(0))
334-
current[[paste0(model, "_pkgs")]] <- dplyr::tibble(engine = character(0), pkg = list())
335-
current[[paste0(model, "_modes")]] <- "unknown"
336-
current[[paste0(model, "_args")]] <-
329+
set_env_val("models", c(current$models, model))
330+
set_env_val(model, dplyr::tibble(engine = character(0), mode = character(0)))
331+
set_env_val(
332+
paste0(model, "_pkgs"),
333+
dplyr::tibble(engine = character(0), pkg = list())
334+
)
335+
set_env_val(paste0(model, "_modes"), "unknown")
336+
set_env_val(
337+
paste0(model, "_args"),
337338
dplyr::tibble(
338339
engine = character(0),
339340
parsnip = character(0),
340341
original = character(0),
341342
func = list(),
342343
has_submodel = logical(0)
343344
)
344-
current[[paste0(model, "_fit")]] <-
345+
)
346+
set_env_val(
347+
paste0(model, "_fit"),
345348
dplyr::tibble(
346349
engine = character(0),
347350
mode = character(0),
348351
value = list()
349352
)
350-
current[[paste0(model, "_predict")]] <-
353+
)
354+
set_env_val(
355+
paste0(model, "_predict"),
351356
dplyr::tibble(
352357
engine = character(0),
353358
mode = character(0),
354359
type = character(0),
355360
value = list()
356361
)
362+
)
357363

358364
invisible(NULL)
359365
}
@@ -372,9 +378,11 @@ set_model_mode <- function(model, mode) {
372378
if (!any(current$modes == mode)) {
373379
current$modes <- unique(c(current$modes, mode))
374380
}
375-
current[[paste0(model, "_modes")]] <-
376-
unique(c(current[[paste0(model, "_modes")]], mode))
377381

382+
set_env_val(
383+
paste0(model, "_modes"),
384+
unique(c(get_from_env(paste0(model, "_modes")), mode))
385+
)
378386
invisible(NULL)
379387
}
380388

@@ -392,20 +400,21 @@ set_model_engine <- function(model, mode, eng) {
392400
current <- get_model_env()
393401

394402
new_eng <- dplyr::tibble(engine = eng, mode = mode)
395-
old_eng <- current[[model]]
403+
old_eng <- get_from_env(model)
404+
396405
engs <-
397406
old_eng %>%
398407
dplyr::bind_rows(new_eng) %>%
399408
dplyr::distinct()
400409

401-
current[[model]] <- engs
410+
set_env_val(model, engs)
402411

403412
invisible(NULL)
404413
}
405414

406415

407416
# ------------------------------------------------------------------------------
408-
417+
#' @importFrom vctrs vec_unique
409418
#' @rdname set_new_model
410419
#' @keywords internal
411420
#' @export
@@ -418,7 +427,7 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
418427
check_submodels_val(has_submodel)
419428

420429
current <- get_model_env()
421-
old_args <- current[[paste0(model, "_args")]]
430+
old_args <- get_from_env(paste0(model, "_args"))
422431

423432
new_arg <-
424433
dplyr::tibble(
@@ -429,22 +438,13 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
429438
has_submodel = has_submodel
430439
)
431440

432-
# Do not allow people to modify existing arguments
433-
combined <-
434-
dplyr::inner_join(new_arg %>% dplyr::select(engine, parsnip, original),
435-
old_args %>% dplyr::select(engine, parsnip, original),
436-
by = c("engine", "parsnip", "original"))
437-
if (nrow(combined) != 0) {
438-
stop("A model argument already exists for ", model, " using the ",
439-
eng, " engine. You cannot overwrite arguments.", call. = FALSE)
440-
}
441-
442441
updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE)
443442
if (inherits(updated, "try-error")) {
444443
stop("An error occured when adding the new argument.", call. = FALSE)
445444
}
446445

447-
current[[paste0(model, "_args")]] <- updated
446+
updated <- vctrs::vec_unique(updated)
447+
set_env_val(paste0(model, "_args"), updated)
448448

449449
invisible(NULL)
450450
}
@@ -461,8 +461,8 @@ set_dependency <- function(model, eng, pkg) {
461461
check_pkg_val(pkg)
462462

463463
current <- get_model_env()
464-
model_info <- current[[model]]
465-
pkg_info <- current[[paste0(model, "_pkgs")]]
464+
model_info <- get_from_env(model)
465+
pkg_info <- get_from_env(paste0(model, "_pkgs"))
466466

467467
has_engine <-
468468
model_info %>%
@@ -491,7 +491,8 @@ set_dependency <- function(model, eng, pkg) {
491491
dplyr::filter(engine != eng) %>%
492492
dplyr::bind_rows(existing_pkgs)
493493
}
494-
current[[paste0(model, "_pkgs")]] <- pkg_info
494+
495+
set_env_val(paste0(model, "_pkgs"), pkg_info)
495496

496497
invisible(NULL)
497498
}
@@ -522,8 +523,8 @@ set_fit <- function(model, mode, eng, value) {
522523
check_fit_info(value)
523524

524525
current <- get_model_env()
525-
model_info <- current[[paste0(model)]]
526-
old_fits <- current[[paste0(model, "_fit")]]
526+
model_info <- get_from_env(model)
527+
old_fits <- get_from_env(paste0(model, "_fit"))
527528

528529
has_engine <-
529530
model_info %>%
@@ -558,7 +559,10 @@ set_fit <- function(model, mode, eng, value) {
558559
stop("An error occured when adding the new fit module", call. = FALSE)
559560
}
560561

561-
current[[paste0(model, "_fit")]] <- updated
562+
set_env_val(
563+
paste0(model, "_fit"),
564+
updated
565+
)
562566

563567
invisible(NULL)
564568
}
@@ -588,8 +592,8 @@ set_pred <- function(model, mode, eng, type, value) {
588592
check_pred_info(value, type)
589593

590594
current <- get_model_env()
591-
model_info <- current[[paste0(model)]]
592-
old_fits <- current[[paste0(model, "_predict")]]
595+
model_info <- get_from_env(model)
596+
old_fits <- get_from_env(paste0(model, "_predict"))
593597

594598
has_engine <-
595599
model_info %>%
@@ -625,7 +629,7 @@ set_pred <- function(model, mode, eng, type, value) {
625629
stop("An error occured when adding the new fit module", call. = FALSE)
626630
}
627631

628-
current[[paste0(model, "_predict")]] <- updated
632+
set_env_val(paste0(model, "_predict"), updated)
629633

630634
invisible(NULL)
631635
}
@@ -660,11 +664,11 @@ show_model_info <- function(model) {
660664

661665
cat(
662666
" modes:",
663-
paste0(current[[paste0(model, "_modes")]], collapse = ", "),
667+
paste0(get_from_env(paste0(model, "_modes")), collapse = ", "),
664668
"\n\n"
665669
)
666670

667-
engines <- current[[paste0(model)]]
671+
engines <- get_from_env(model)
668672
if (nrow(engines) > 0) {
669673
cat(" engines: \n")
670674
engines %>%
@@ -686,7 +690,7 @@ show_model_info <- function(model) {
686690
cat(" no registered engines.\n\n")
687691
}
688692

689-
args <- current[[paste0(model, "_args")]]
693+
args <- get_from_env(paste0(model, "_args"))
690694
if (nrow(args) > 0) {
691695
cat(" arguments: \n")
692696
args %>%
@@ -710,7 +714,7 @@ show_model_info <- function(model) {
710714
cat(" no registered arguments.\n\n")
711715
}
712716

713-
fits <- current[[paste0(model, "_fit")]]
717+
fits <- get_from_env(paste0(model, "_fit"))
714718
if (nrow(fits) > 0) {
715719
cat(" fit modules:\n")
716720
fits %>%
@@ -723,7 +727,7 @@ show_model_info <- function(model) {
723727
cat(" no registered fit modules.\n\n")
724728
}
725729

726-
preds <- current[[paste0(model, "_predict")]]
730+
preds <- get_from_env(paste0(model, "_predict"))
727731
if (nrow(preds) > 0) {
728732
cat(" prediction modules:\n")
729733
preds %>%

tests/testthat/test_registration.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,18 @@ test_that('adding a new argument', {
136136
has_submodel = FALSE
137137
)
138138

139-
expect_error(
140-
set_model_arg(
141-
model = "sponge",
142-
eng = "gum",
143-
parsnip = "modeling",
144-
original = "modelling",
145-
func = list(pkg = "foo", fun = "bar"),
146-
has_submodel = FALSE
147-
)
139+
set_model_arg(
140+
model = "sponge",
141+
eng = "gum",
142+
parsnip = "modeling",
143+
original = "modelling",
144+
func = list(pkg = "foo", fun = "bar"),
145+
has_submodel = FALSE
148146
)
149147

148+
args <- get_from_env("sponge_args")
149+
expect_equal(sum(args$parsnip == "modeling"), 1)
150+
150151
test_by_col(
151152
get_from_env("sponge_args"),
152153
tibble(engine = "gum", parsnip = "modeling", original = "modelling",

0 commit comments

Comments
 (0)