Skip to content

Commit 7ac976b

Browse files
committed
wip return mean and std variable importance
1 parent acc5290 commit 7ac976b

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

pymc3/distributions/bart.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,9 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):
144144
idx_split_variable = current_split_node.idx_split_variable
145145
split_value = current_split_node.split_value
146146

147-
left_idx = np.nonzero(self.X[idx_data_points, idx_split_variable] <= split_value)
147+
left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
148148
left_node_idx_data_points = idx_data_points[left_idx]
149-
right_idx = np.nonzero(~(self.X[idx_data_points, idx_split_variable] <= split_value))
150-
right_node_idx_data_points = idx_data_points[right_idx]
149+
right_node_idx_data_points = idx_data_points[~left_idx]
151150

152151
return left_node_idx_data_points, right_node_idx_data_points
153152

pymc3/sampling.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,11 @@ def sample(
607607
trace.report._t_sampling = t_sampling
608608

609609
if "variable_inclusion" in trace.stat_names:
610-
variable_inclusion = trace.get_sampler_stats("variable_inclusion")[::-draws]
611-
trace.report.variable_importance = np.mean([vi / vi.sum() for vi in variable_inclusion], 0)
610+
variable_inclusion = np.vstack(trace.get_sampler_stats("variable_inclusion"))
611+
variable_inclusion = np.split(variable_inclusion, 50)
612+
dada = np.vstack([v.sum(0) / v.sum() for v in variable_inclusion])
613+
trace.report.variable_importance_m = dada.mean(0)
614+
trace.report.variable_importance_s = dada.std(0)
612615

613616
n_chains = len(trace.chains)
614617
_log.info(

pymc3/step_methods/pgbart.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
6060
self.idx = 0
6161
if chunk == "auto":
6262
self.chunk = max(1, int(self.bart.m * 0.1))
63-
self.variable_inclusion = np.zeros(self.bart.num_variates)
63+
self.variable_inclusion = np.zeros(self.bart.num_variates, dtype="int")
6464
self.num_particles = num_particles
6565
self.log_num_particles = np.log(num_particles)
6666
self.indices = list(range(1, num_particles))
@@ -77,6 +77,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
7777
def astep(self, _):
7878
bart = self.bart
7979
num_observations = bart.num_observations
80+
variable_inclusion = self.variable_inclusion
8081

8182
# For the tuning phase we restrict max_stages to a low number, otherwise it is almost certain
8283
# we will reach max_stages given that our first set of m trees is not good at all.
@@ -141,10 +142,11 @@ def astep(self, _):
141142
bart.sum_trees_output = bart.Y - R_j + new_prediction
142143

143144
if not self.tune:
145+
variable_inclusion = self.variable_inclusion
144146
for index in new_tree.used_variates:
145-
self.variable_inclusion[index] += 1
147+
variable_inclusion[index] += 1
146148

147-
stats = {"variable_inclusion": self.variable_inclusion}
149+
stats = {"variable_inclusion": variable_inclusion}
148150

149151
return bart.sum_trees_output, [stats]
150152

0 commit comments

Comments
 (0)