Skip to content

Commit ca77a9b

Browse files
authored
Add tensorflow_decision_forests package (#1183)
Add also smoke test to prevent regression.
1 parent f82a0ad commit ca77a9b

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ RUN pip install pysal && \
169169
pip install -f https://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && \
170170
pip install tensorflow-gcs-config==2.6.0 && \
171171
pip install tensorflow-addons==0.14.0 && \
172+
pip install tensorflow_decision_forests==0.2.0 && \
172173
/tmp/clean-layer.sh
173174

174175
RUN apt-get install -y libfreetype6-dev && \
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import unittest
2+
3+
import numpy as np
4+
import pandas as pd
5+
import tensorflow_decision_forests as tfdf
6+
7+
class TestTensorflowDecisionForest(unittest.TestCase):
8+
def test_fit(self):
9+
train_df = pd.read_csv("/input/tests/data/train.csv")
10+
11+
# Convert the dataset into a TensorFlow dataset.
12+
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="label")
13+
14+
# Train the model
15+
model = tfdf.keras.RandomForestModel(num_trees=1)
16+
model.fit(train_ds)
17+
18+
self.assertEqual(1, model.count_params())

0 commit comments

Comments
 (0)