@@ -79,11 +79,8 @@ set_in_env <- function(...) {
79
79
# ' @keywords internal
80
80
# ' @export
81
81
set_env_val <- function (name , value ) {
82
- if (length(name ) != 1 | length(value ) != 1 ) {
83
- stop(" `name` and `value` should both be a single value." , call. = FALSE )
84
- }
85
- if (! is.character(name )) {
86
- stop(" `name` should be a character value." , call. = FALSE )
82
+ if (length(name ) != 1 || ! is.character(name )) {
83
+ stop(" `name` should be a single character value." , call. = FALSE )
87
84
}
88
85
mod_env <- get_model_env()
89
86
x <- list (value )
@@ -329,31 +326,40 @@ set_new_model <- function(model) {
329
326
330
327
current <- get_model_env()
331
328
332
- current $ models <- c(current $ models , model )
333
- current [[model ]] <- dplyr :: tibble(engine = character (0 ), mode = character (0 ))
334
- current [[paste0(model , " _pkgs" )]] <- dplyr :: tibble(engine = character (0 ), pkg = list ())
335
- current [[paste0(model , " _modes" )]] <- " unknown"
336
- current [[paste0(model , " _args" )]] <-
329
+ set_env_val(" models" , c(current $ models , model ))
330
+ set_env_val(model , dplyr :: tibble(engine = character (0 ), mode = character (0 )))
331
+ set_env_val(
332
+ paste0(model , " _pkgs" ),
333
+ dplyr :: tibble(engine = character (0 ), pkg = list ())
334
+ )
335
+ set_env_val(paste0(model , " _modes" ), " unknown" )
336
+ set_env_val(
337
+ paste0(model , " _args" ),
337
338
dplyr :: tibble(
338
339
engine = character (0 ),
339
340
parsnip = character (0 ),
340
341
original = character (0 ),
341
342
func = list (),
342
343
has_submodel = logical (0 )
343
344
)
344
- current [[paste0(model , " _fit" )]] <-
345
+ )
346
+ set_env_val(
347
+ paste0(model , " _fit" ),
345
348
dplyr :: tibble(
346
349
engine = character (0 ),
347
350
mode = character (0 ),
348
351
value = list ()
349
352
)
350
- current [[paste0(model , " _predict" )]] <-
353
+ )
354
+ set_env_val(
355
+ paste0(model , " _predict" ),
351
356
dplyr :: tibble(
352
357
engine = character (0 ),
353
358
mode = character (0 ),
354
359
type = character (0 ),
355
360
value = list ()
356
361
)
362
+ )
357
363
358
364
invisible (NULL )
359
365
}
@@ -372,9 +378,11 @@ set_model_mode <- function(model, mode) {
372
378
if (! any(current $ modes == mode )) {
373
379
current $ modes <- unique(c(current $ modes , mode ))
374
380
}
375
- current [[paste0(model , " _modes" )]] <-
376
- unique(c(current [[paste0(model , " _modes" )]], mode ))
377
381
382
+ set_env_val(
383
+ paste0(model , " _modes" ),
384
+ unique(c(get_from_env(paste0(model , " _modes" )), mode ))
385
+ )
378
386
invisible (NULL )
379
387
}
380
388
@@ -392,20 +400,21 @@ set_model_engine <- function(model, mode, eng) {
392
400
current <- get_model_env()
393
401
394
402
new_eng <- dplyr :: tibble(engine = eng , mode = mode )
395
- old_eng <- current [[model ]]
403
+ old_eng <- get_from_env(model )
404
+
396
405
engs <-
397
406
old_eng %> %
398
407
dplyr :: bind_rows(new_eng ) %> %
399
408
dplyr :: distinct()
400
409
401
- current [[ model ]] <- engs
410
+ set_env_val( model , engs )
402
411
403
412
invisible (NULL )
404
413
}
405
414
406
415
407
416
# ------------------------------------------------------------------------------
408
-
417
+ # ' @importFrom vctrs vec_unique
409
418
# ' @rdname set_new_model
410
419
# ' @keywords internal
411
420
# ' @export
@@ -418,7 +427,7 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
418
427
check_submodels_val(has_submodel )
419
428
420
429
current <- get_model_env()
421
- old_args <- current [[ paste0(model , " _args" )]]
430
+ old_args <- get_from_env( paste0(model , " _args" ))
422
431
423
432
new_arg <-
424
433
dplyr :: tibble(
@@ -429,22 +438,13 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
429
438
has_submodel = has_submodel
430
439
)
431
440
432
- # Do not allow people to modify existing arguments
433
- combined <-
434
- dplyr :: inner_join(new_arg %> % dplyr :: select(engine , parsnip , original ),
435
- old_args %> % dplyr :: select(engine , parsnip , original ),
436
- by = c(" engine" , " parsnip" , " original" ))
437
- if (nrow(combined ) != 0 ) {
438
- stop(" A model argument already exists for " , model , " using the " ,
439
- eng , " engine. You cannot overwrite arguments." , call. = FALSE )
440
- }
441
-
442
441
updated <- try(dplyr :: bind_rows(old_args , new_arg ), silent = TRUE )
443
442
if (inherits(updated , " try-error" )) {
444
443
stop(" An error occured when adding the new argument." , call. = FALSE )
445
444
}
446
445
447
- current [[paste0(model , " _args" )]] <- updated
446
+ updated <- vctrs :: vec_unique(updated )
447
+ set_env_val(paste0(model , " _args" ), updated )
448
448
449
449
invisible (NULL )
450
450
}
@@ -461,8 +461,8 @@ set_dependency <- function(model, eng, pkg) {
461
461
check_pkg_val(pkg )
462
462
463
463
current <- get_model_env()
464
- model_info <- current [[ model ]]
465
- pkg_info <- current [[ paste0(model , " _pkgs" )]]
464
+ model_info <- get_from_env( model )
465
+ pkg_info <- get_from_env( paste0(model , " _pkgs" ))
466
466
467
467
has_engine <-
468
468
model_info %> %
@@ -491,7 +491,8 @@ set_dependency <- function(model, eng, pkg) {
491
491
dplyr :: filter(engine != eng ) %> %
492
492
dplyr :: bind_rows(existing_pkgs )
493
493
}
494
- current [[paste0(model , " _pkgs" )]] <- pkg_info
494
+
495
+ set_env_val(paste0(model , " _pkgs" ), pkg_info )
495
496
496
497
invisible (NULL )
497
498
}
@@ -522,8 +523,8 @@ set_fit <- function(model, mode, eng, value) {
522
523
check_fit_info(value )
523
524
524
525
current <- get_model_env()
525
- model_info <- current [[paste0 (model )]]
526
- old_fits <- current [[ paste0(model , " _fit" )]]
526
+ model_info <- get_from_env (model )
527
+ old_fits <- get_from_env( paste0(model , " _fit" ))
527
528
528
529
has_engine <-
529
530
model_info %> %
@@ -558,7 +559,10 @@ set_fit <- function(model, mode, eng, value) {
558
559
stop(" An error occured when adding the new fit module" , call. = FALSE )
559
560
}
560
561
561
- current [[paste0(model , " _fit" )]] <- updated
562
+ set_env_val(
563
+ paste0(model , " _fit" ),
564
+ updated
565
+ )
562
566
563
567
invisible (NULL )
564
568
}
@@ -588,8 +592,8 @@ set_pred <- function(model, mode, eng, type, value) {
588
592
check_pred_info(value , type )
589
593
590
594
current <- get_model_env()
591
- model_info <- current [[paste0 (model )]]
592
- old_fits <- current [[ paste0(model , " _predict" )]]
595
+ model_info <- get_from_env (model )
596
+ old_fits <- get_from_env( paste0(model , " _predict" ))
593
597
594
598
has_engine <-
595
599
model_info %> %
@@ -625,7 +629,7 @@ set_pred <- function(model, mode, eng, type, value) {
625
629
stop(" An error occured when adding the new fit module" , call. = FALSE )
626
630
}
627
631
628
- current [[ paste0(model , " _predict" )]] <- updated
632
+ set_env_val( paste0(model , " _predict" ), updated )
629
633
630
634
invisible (NULL )
631
635
}
@@ -660,11 +664,11 @@ show_model_info <- function(model) {
660
664
661
665
cat(
662
666
" modes:" ,
663
- paste0(current [[ paste0(model , " _modes" )]] , collapse = " , " ),
667
+ paste0(get_from_env( paste0(model , " _modes" )) , collapse = " , " ),
664
668
" \n\n "
665
669
)
666
670
667
- engines <- current [[paste0 (model )]]
671
+ engines <- get_from_env (model )
668
672
if (nrow(engines ) > 0 ) {
669
673
cat(" engines: \n " )
670
674
engines %> %
@@ -686,7 +690,7 @@ show_model_info <- function(model) {
686
690
cat(" no registered engines.\n\n " )
687
691
}
688
692
689
- args <- current [[ paste0(model , " _args" )]]
693
+ args <- get_from_env( paste0(model , " _args" ))
690
694
if (nrow(args ) > 0 ) {
691
695
cat(" arguments: \n " )
692
696
args %> %
@@ -710,7 +714,7 @@ show_model_info <- function(model) {
710
714
cat(" no registered arguments.\n\n " )
711
715
}
712
716
713
- fits <- current [[ paste0(model , " _fit" )]]
717
+ fits <- get_from_env( paste0(model , " _fit" ))
714
718
if (nrow(fits ) > 0 ) {
715
719
cat(" fit modules:\n " )
716
720
fits %> %
@@ -723,7 +727,7 @@ show_model_info <- function(model) {
723
727
cat(" no registered fit modules.\n\n " )
724
728
}
725
729
726
- preds <- current [[ paste0(model , " _predict" )]]
730
+ preds <- get_from_env( paste0(model , " _predict" ))
727
731
if (nrow(preds ) > 0 ) {
728
732
cat(" prediction modules:\n " )
729
733
preds %> %
0 commit comments