Skip to content

Commit c8649e5

Browse files
authored
Merge branch 'main' into upgrade-tf2.9
2 parents 49aafb7 + ca77a9b commit c8649e5

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ RUN pip install pysal && \
176176
pip install tensorflow-gcs-config==${TENSORFLOW_VERSION} && \
177177
# TODO(b/207851560) Upgrade to 0.17.1 once the base image with TensorFlow 2.9.1 is out.
178178
pip install tensorflow-addons==0.17.0 && \
179+
pip install tensorflow_decision_forests==0.2.0 && \
179180
/tmp/clean-layer.sh
180181

181182
RUN apt-get install -y libfreetype6-dev && \
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import unittest
2+
3+
import numpy as np
4+
import pandas as pd
5+
import tensorflow_decision_forests as tfdf
6+
7+
class TestTensorflowDecisionForest(unittest.TestCase):
8+
def test_fit(self):
9+
train_df = pd.read_csv("/input/tests/data/train.csv")
10+
11+
# Convert the dataset into a TensorFlow dataset.
12+
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="label")
13+
14+
# Train the model
15+
model = tfdf.keras.RandomForestModel(num_trees=1)
16+
model.fit(train_ds)
17+
18+
self.assertEqual(1, model.count_params())

0 commit comments

Comments
 (0)