Skip to content

Commit ff20477

Browse files
hfricksimonpcouch
andauthored
Improve readability of docs for predict.model_fit() (#866)
* Standardize language across different types * document * Apply suggestions from code review Co-authored-by: Simon P. Couch <[email protected]> --------- Co-authored-by: Simon P. Couch <[email protected]>
1 parent 3aa1e6c commit ff20477

File tree

4 files changed

+109
-104
lines changed

4 files changed

+109
-104
lines changed

R/predict.R

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,42 @@
44
#' `predict()` can be used for all types of models and uses the
55
#' "type" argument for more specificity.
66
#'
7-
#' @param object An object of class `model_fit`
7+
#' @param object An object of class `model_fit`.
88
#' @param new_data A rectangular data object, such as a data frame.
99
#' @param type A single character value or `NULL`. Possible values
10-
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
11-
#' "hazard", "survival", or "raw". When `NULL`, `predict()` will choose an
12-
#' appropriate value based on the model's mode.
10+
#' are `"numeric"`, `"class"`, `"prob"`, `"conf_int"`, `"pred_int"`,
11+
#' `"quantile"`, `"time"`, `"hazard"`, `"survival"`, or `"raw"`. When `NULL`,
12+
#' `predict()` will choose an appropriate value based on the model's mode.
1313
#' @param opts A list of optional arguments to the underlying
1414
#' predict function that will be used when `type = "raw"`. The
1515
#' list should not include options for the model object or the
1616
#' new data being predicted.
17-
#' @param ... Arguments to the underlying model's prediction
18-
#' function cannot be passed here (see `opts`). There are some
19-
#' `parsnip` related options that can be passed, depending on the
20-
#' value of `type`. Possible arguments are:
17+
#' @param ... Additional `parsnip`-related options, depending on the
18+
#' value of `type`. Arguments to the underlying model's prediction
19+
#' function cannot be passed here (use the `opts` argument instead).
20+
#' Possible arguments are:
2121
#' \itemize{
22-
#' \item `interval`: for `type`s of "survival" and "quantile", should
22+
#' \item `interval`: for `type` equal to `"survival"` or `"quantile"`, should
2323
#' interval estimates be added, if available? Options are `"none"`
2424
#' and `"confidence"`.
25-
#' \item `level`: for `type`s of "conf_int", "pred_int", and "survival"
25+
#' \item `level`: for `type` equal to `"conf_int"`, `"pred_int"`, or `"survival"`,
2626
#' this is the parameter for the tail area of the intervals
2727
#' (e.g. confidence level for confidence intervals).
28-
#' Default value is 0.95.
29-
#' \item `std_error`: add the standard error of fit or prediction (on
30-
#' the scale of the linear predictors) for `type`s of "conf_int"
31-
#' and "pred_int". Default value is `FALSE`.
32-
#' \item `quantile`: the quantile(s) for quantile regression
33-
#' (not implemented yet)
34-
#' \item `time`: the time(s) for hazard and survival probability estimates.
28+
#' Default value is `0.95`.
29+
#' \item `std_error`: for `type` equal to `"conf_int"` or `"pred_int"`, add
30+
#' the standard error of fit or prediction (on the scale of the
31+
#' linear predictors). Default value is `FALSE`.
32+
#' \item `quantile`: for `type` equal to `quantile`, the quantiles of the
33+
#' distribution. Default is `(1:9)/10`.
34+
#' \item `time`: for `type` equal to `"survival"` or `"hazard"`, the
35+
#' time points at which the survival probability or hazard is estimated.
3536
#' }
36-
#' @details If "type" is not supplied to `predict()`, then a choice
37-
#' is made:
37+
#' @details For `type = NULL`, `predict()` uses
3838
#'
3939
#' * `type = "numeric"` for regression models,
4040
#' * `type = "class"` for classification, and
4141
#' * `type = "time"` for censored regression.
4242
#'
43-
#' `predict()` is designed to provide a tidy result (see "Value"
44-
#' section below) in a tibble output format.
45-
#'
4643
#' ## Interval predictions
4744
#'
4845
#' When using `type = "conf_int"` and `type = "pred_int"`, the options
@@ -58,37 +55,42 @@
5855
#' have the opposite sign as what the underlying model's `predict()` method
5956
#' produces. Set `increasing = FALSE` to suppress this behavior.
6057
#'
61-
#' @return With the exception of `type = "raw"`, the results of
62-
#' `predict.model_fit()` will be a tibble as many rows in the output
63-
#' as there are rows in `new_data` and the column names will be
64-
#' predictable.
58+
#' @return With the exception of `type = "raw"`, the result of
59+
#' `predict.model_fit()`
60+
#'
61+
#' * is a tibble
62+
#' * has as many rows as there are rows in `new_data`
63+
#' * has standardized column names, see below:
64+
#'
65+
#' For `type = "numeric"`, the tibble has a `.pred` column for a single
66+
#' outcome and `.pred_Yname` columns for a multivariate outcome.
6567
#'
66-
#' For numeric results with a single outcome, the tibble will have
67-
#' a `.pred` column and `.pred_Yname` for multivariate results.
68+
#' For `type = "class"`, the tibble has a `.pred_class` column.
6869
#'
69-
#' For hard class predictions, the column is named `.pred_class`
70-
#' and, when `type = "prob"`, the columns are `.pred_classlevel`.
70+
#' For `type = "prob"`, the tibble has `.pred_classlevel` columns.
7171
#'
72-
#' `type = "conf_int"` and `type = "pred_int"` return tibbles with
73-
#' columns `.pred_lower` and `.pred_upper` with an attribute for
74-
#' the confidence level. In the case where intervals can be
75-
#' produces for class probabilities (or other non-scalar outputs),
76-
#' the columns will be named `.pred_lower_classlevel` and so on.
72+
#' For `type = "conf_int"` and `type = "pred_int"`, the tibble has
73+
#' `.pred_lower` and `.pred_upper` columns with an attribute for
74+
#' the confidence level. In the case where intervals can be
75+
#' produces for class probabilities (or other non-scalar outputs),
76+
#' the columns are named `.pred_lower_classlevel` and so on.
7777
#'
78-
#' Quantile predictions return a tibble with a column `.pred`, which is
78+
#' For `type = "quantile"`, the tibble has a `.pred` column, which is
7979
#' a list-column. Each list element contains a tibble with columns
8080
#' `.pred` and `.quantile` (and perhaps other columns).
8181
#'
82-
#' Using `type = "raw"` with `predict.model_fit()` will return
83-
#' the unadulterated results of the prediction function.
82+
#' For `type = "time"`, the tibble has a `.pred_time` column.
8483
#'
85-
#' For censored regression:
84+
#' For `type = "survival"`, the tibble has a `.pred` column, which is
85+
#' a list-column. Each list element contains a tibble with columns
86+
#' `.time` and `.pred_survival` (and perhaps other columns).
87+
#'
88+
#' For `type = "hazard"`, the tibble has a `.pred` column, which is
89+
#' a list-column. Each list element contains a tibble with columns
90+
#' `.time` and `.pred_hazard` (and perhaps other columns).
8691
#'
87-
#' * `type = "time"` produces a column `.pred_time`.
88-
#' * `type = "hazard"` results in a list column `.pred` containing tibbles
89-
#' with a column `.pred_hazard`.
90-
#' * `type = "survival"` results in a list column `.pred` containing tibbles
91-
#' with a `.pred_survival` column.
92+
#' Using `type = "raw"` with `predict.model_fit()` will return
93+
#' the unadulterated results of the prediction function.
9294
#'
9395
#' In the case of Spark-based models, since table columns cannot
9496
#' contain dots, the same convention is used except 1) no dots

man/bart-internal.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/other_predict.Rd

Lines changed: 15 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict.model_fit.Rd

Lines changed: 45 additions & 43 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)