Skip to content

Commit 4b0808f

Browse files
committed
start of unit tests
1 parent 1113ed5 commit 4b0808f

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ export(fit_xy)
9696
export(fit_xy.model_spec)
9797
export(get_dependency)
9898
export(get_fit)
99+
export(get_from_env)
99100
export(get_model_env)
100101
export(get_pred_type)
101102
export(keras_mlp)

R/aaa_models.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,4 +711,12 @@ show_model_info <- function(model) {
711711
invisible(NULL)
712712
}
713713

714+
#' @rdname get_model_env
715+
#' @keywords internal
716+
#' @export
717+
#' @param items A character string of objects in the model environment.
718+
get_from_env <- function(items) {
719+
mod_env <- get_model_env()
720+
env_get(mod_env, items)
721+
}
714722

man/check_mod_val.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_model_env.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_registration.R

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
library(parsnip)
2+
library(dplyr)
3+
library(rlang)
4+
library(testthat)
5+
6+
# ------------------------------------------------------------------------------
7+
8+
context("model registration")
9+
#source("helpers.R")
10+
11+
test_by_col <- function(a, b) {
12+
for(i in union(names(a), names(b))) {
13+
expect_equal(a[[i]], b[[i]])
14+
}
15+
}
16+
17+
# ------------------------------------------------------------------------------
18+
19+
test_that('adding a new model', {
20+
set_new_model("sponge")
21+
22+
mod_items <- get_model_env() %>% env_names()
23+
sponges <- grep("sponge", mod_items, value = TRUE)
24+
exp_obj <- c('sponge_modes', 'sponge_fit', 'sponge_args',
25+
'sponge_predict', 'sponge_pkgs', 'sponge')
26+
expect_equal(sort(sponges), sort(exp_obj))
27+
28+
expect_equal(
29+
get_from_env("sponge"),
30+
tibble(engine = character(0), mode = character(0))
31+
)
32+
33+
test_by_col(
34+
get_from_env("sponge_pkgs"),
35+
tibble(engine = character(0), pkg = character(0))
36+
)
37+
38+
expect_equal(
39+
get_from_env("sponge_modes"), "unknown"
40+
)
41+
42+
test_by_col(
43+
get_from_env("sponge_args"),
44+
tibble(engine = character(0), parsnip = character(0),
45+
original = character(0), func = vector("list"))
46+
)
47+
48+
test_by_col(
49+
get_from_env("sponge_fit"),
50+
tibble(engine = character(0), mode = character(0), value = vector("list"))
51+
)
52+
53+
test_by_col(
54+
get_from_env("sponge_predict"),
55+
tibble(engine = character(0), mode = character(0),
56+
type = character(0), value = vector("list"))
57+
)
58+
59+
expect_error(set_new_model())
60+
# TODO expect_error(set_new_model(2))
61+
# TODO expect_error(set_new_model(letters[1:2]))
62+
})
63+
64+
65+
# ------------------------------------------------------------------------------
66+
67+
test_that('adding a new mode', {
68+
set_model_mode("sponge", "classification")
69+
70+
expect_equal(get_from_env("sponge_modes"), c("unknown", "classification"))
71+
72+
# TODO expect_error(set_model_mode("sponge", "banana"))
73+
# TODO expect_error(set_model_mode("sponge", "classification"))
74+
75+
})
76+
77+
78+
# ------------------------------------------------------------------------------
79+
80+
test_that('adding a new engine', {
81+
set_model_engine("sponge", "classification", "gum")
82+
83+
test_by_col(
84+
get_from_env("sponge"),
85+
tibble(engine = "gum", mode = "classification")
86+
)
87+
88+
89+
expect_equal(get_from_env("sponge_modes"), c("unknown", "classification"))
90+
91+
# TODO check for bad mode, check for duplicate
92+
93+
})
94+
95+
96+
# ------------------------------------------------------------------------------
97+
98+
test_that('adding a new package', {
99+
set_dependency("sponge", "gum", "trident")
100+
101+
expect_error(set_dependency("sponge", "gum", letters[1:2]))
102+
103+
test_by_col(
104+
get_from_env("sponge_pkgs"),
105+
tibble(engine = "gum", pkg = list("trident"))
106+
)
107+
})
108+
109+
110+
# ------------------------------------------------------------------------------
111+
112+
test_that('adding a new argument', {
113+
114+
})
115+
116+
117+
118+
# ------------------------------------------------------------------------------
119+
120+
test_that('adding a new fit', {
121+
122+
})
123+
124+
125+
# ------------------------------------------------------------------------------
126+
127+
test_that('adding a new predict method', {
128+
129+
})
130+

0 commit comments

Comments
 (0)