Skip to content

Commit c6d2aa0

Browse files
committed
Added kernel debugging script and updated kernel unit tests
1 parent e2413f0 commit c6d2aa0

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

demo/debug/kernel.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
from stochtree import Dataset, ForestContainer, compute_forest_leaf_indices
3+
4+
# Create dataset
5+
X = np.array(
6+
[[1.5, 8.7, 1.2],
7+
[2.7, 3.4, 5.4],
8+
[3.6, 1.2, 9.3],
9+
[4.4, 5.4, 10.4],
10+
[5.3, 9.3, 3.6],
11+
[6.1, 10.4, 4.4]]
12+
)
13+
n, p = X.shape
14+
num_trees = 2
15+
output_dim = 1
16+
forest_dataset = Dataset()
17+
forest_dataset.add_covariates(X)
18+
forest_samples = ForestContainer(num_trees, output_dim, True, False)
19+
20+
# Initialize a forest with constant root predictions
21+
forest_samples.add_sample(0.)
22+
23+
# Split the root of the first tree in the ensemble at X[,1] > 4.0
24+
forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.)
25+
26+
# Check that regular and "raw" predictions are the same (since the leaf is constant)
27+
computed_indices = compute_forest_leaf_indices(forest_samples, X)
28+
29+
# Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
30+
forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5)
31+
32+
# Check that regular and "raw" predictions are the same (since the leaf is constant)
33+
computed_indices = compute_forest_leaf_indices(forest_samples, X)

test/python/test_kernel.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
)
1010

1111

12-
class TestJson:
13-
def test_value(self):
12+
class TestKernel:
13+
def test_forest(self):
1414
# Create dataset
1515
X = np.array(
1616
[[1.5, 8.7, 1.2],
@@ -30,19 +30,11 @@ def test_value(self):
3030
# Initialize a forest with constant root predictions
3131
forest_samples.add_sample(0.)
3232

33-
# Check that regular and "raw" predictions are the same (since the leaf is constant)
34-
pred = forest_samples.predict(forest_dataset)
35-
pred_raw = forest_samples.predict_raw(forest_dataset)
36-
37-
# Assertion
38-
np.testing.assert_almost_equal(pred, pred_raw)
39-
4033
# Split the root of the first tree in the ensemble at X[,1] > 4.0
4134
forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.)
4235

4336
# Check that regular and "raw" predictions are the same (since the leaf is constant)
4437
computed = compute_forest_leaf_indices(forest_samples, X)
45-
print(computed)
4638
expected = np.array([
4739
[0],
4840
[0],
@@ -66,7 +58,6 @@ def test_value(self):
6658

6759
# Check that regular and "raw" predictions are the same (since the leaf is constant)
6860
computed = compute_forest_leaf_indices(forest_samples, X)
69-
print(computed)
7061
expected = np.array([
7162
[2],
7263
[1],

0 commit comments

Comments
 (0)