Skip to content

Commit e03f5bf

Browse files
authored
make linear response more robust and fix bug with predictions (#5080)
1 parent ce447cc commit e03f5bf

File tree

4 files changed

+42
-30
lines changed

4 files changed

+42
-30
lines changed

pymc/bart/bart.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616

1717
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
18+
from pandas import DataFrame, Series
1819

1920
from pymc.distributions.distribution import NoDistribution
2021

@@ -93,8 +94,8 @@ class BART(NoDistribution):
9394
Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between 1
9495
and 3.
9596
response : str
96-
How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
97-
``mix`` (default).
97+
How the leaf_node values are computed. Available options are ``constant`` (default),
98+
``linear`` or ``mix``.
9899
split_prior : array-like
99100
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
100101
1. Otherwise they will be normalized.
@@ -109,12 +110,13 @@ def __new__(
109110
m=50,
110111
alpha=0.25,
111112
k=2,
112-
response="mix",
113+
response="constant",
113114
split_prior=None,
114115
**kwargs,
115116
):
116117

117118
cls.all_trees = []
119+
X, Y = preprocess_XY(X, Y)
118120

119121
bart_op = type(
120122
f"BART_{name}",
@@ -143,3 +145,14 @@ def __new__(
143145
@classmethod
144146
def dist(cls, *params, **kwargs):
145147
return super().dist(params, **kwargs)
148+
149+
150+
def preprocess_XY(X, Y):
151+
if isinstance(Y, (Series, DataFrame)):
152+
Y = Y.to_numpy()
153+
if isinstance(X, (Series, DataFrame)):
154+
X = X.to_numpy()
155+
# X = np.random.normal(X, X.std(0)/100)
156+
Y = Y.astype(float)
157+
X = X.astype(float)
158+
return X, Y

pymc/bart/pgbart.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121

2222
from aesara import function as aesara_function
23-
from pandas import DataFrame, Series
2423

2524
from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
2625
from pymc.bart.bart import BARTRV
@@ -127,11 +126,13 @@ class PGBART(ArrayStepShared):
127126
def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
128127
_log.warning("BART is experimental. Use with caution.")
129128
model = modelcontext(model)
130-
initial_values = model.initial_point
129+
initial_values = model.recompute_initial_point()
131130
value_bart = inputvars(vars)[0]
132131
self.bart = model.values_to_rvs[value_bart].owner.op
133132

134-
self.X, self.Y, self.missing_data = preprocess_XY(self.bart.X, self.bart.Y)
133+
self.X = self.bart.X
134+
self.Y = self.bart.Y
135+
self.missing_data = np.any(np.isnan(self.X))
135136
self.m = self.bart.m
136137
self.alpha = self.bart.alpha
137138
self.k = self.bart.k
@@ -342,16 +343,6 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
342343
particle.old_likelihood_logp = new_likelihood
343344

344345

345-
def preprocess_XY(X, Y):
346-
if isinstance(Y, (Series, DataFrame)):
347-
Y = Y.to_numpy()
348-
if isinstance(X, (Series, DataFrame)):
349-
X = X.to_numpy()
350-
missing_data = np.any(np.isnan(X))
351-
Y = Y.astype(float)
352-
return X, Y, missing_data
353-
354-
355346
class SampleSplittingVariable:
356347
def __init__(self, alpha_prior):
357348
"""
@@ -493,16 +484,19 @@ def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, respon
493484
linear_params = None
494485
if Y_mu_pred.size == 0:
495486
return 0, linear_params
496-
elif Y_mu_pred.size == 1:
497-
mu_mean = Y_mu_pred.item() / m
498487
else:
499-
if response == "constant":
488+
norm = normal.random() * mu_std
489+
if Y_mu_pred.size == 1:
490+
mu_mean = Y_mu_pred.item() / m
491+
elif response == "constant":
500492
mu_mean = mean(Y_mu_pred) / m
501493
elif response == "linear":
502494
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
503495
mu_mean = Y_fit / m
504-
draw = normal.random() * mu_std + mu_mean
505-
return draw, linear_params
496+
linear_params[2] = norm
497+
498+
draw = norm + mu_mean
499+
return draw, linear_params
506500

507501

508502
def fast_mean():
@@ -532,11 +526,14 @@ def linear_fit(X, Y):
532526
xbar = np.sum(X) / n
533527
ybar = np.sum(Y) / n
534528

535-
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
536-
a = ybar - b * xbar
529+
if np.all(X == xbar):
530+
b = 0
531+
else:
532+
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
537533

534+
a = ybar - b * xbar
538535
Y_fit = a + b * X
539-
return Y_fit, (a, b)
536+
return Y_fit, [a, b, 0]
540537

541538
try:
542539
from numba import jit

pymc/bart/tree.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,13 @@ def predict_out_of_sample(self, X, m):
111111
Value of the leaf value where the unobserved point lies.
112112
"""
113113
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
114-
if leaf_node.linear_params is None:
114+
linear_params = leaf_node.linear_params
115+
if linear_params is None:
115116
return leaf_node.value
116117
else:
117118
x = X[split_variable].item()
118-
y_x = leaf_node.linear_params[0] + leaf_node.linear_params[1] * x
119-
return y_x / m
119+
y_x = (linear_params[0] + linear_params[1] * x) / m
120+
return y_x + linear_params[2]
120121

121122
def _traverse_tree(self, x, node_index=0, split_variable=None):
122123
"""
@@ -136,10 +137,10 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
136137
split_variable = current_node.idx_split_variable
137138
if x[split_variable] <= current_node.split_value:
138139
left_child = current_node.get_idx_left_child()
139-
current_node, _ = self._traverse_tree(x, left_child, split_variable)
140+
current_node, split_variable = self._traverse_tree(x, left_child, split_variable)
140141
else:
141142
right_child = current_node.get_idx_right_child()
142-
current_node, _ = self._traverse_tree(x, right_child, split_variable)
143+
current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
143144
return current_node, split_variable
144145

145146
def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):

pymc/tests/test_bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_bart_vi():
4444
)
4545
var_imp /= var_imp.sum()
4646
assert var_imp[0] > var_imp[1:].sum()
47-
np.testing.assert_almost_equal(var_imp.sum(), 1)
47+
assert_almost_equal(var_imp.sum(), 1)
4848

4949

5050
def test_bart_random():
@@ -62,6 +62,7 @@ def test_bart_random():
6262
rng = RandomState(12345)
6363
pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])
6464

65+
assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
6566
assert pred_all.shape == (2, 50)
6667
assert pred_first.shape == (10,)
6768

0 commit comments

Comments
 (0)