File tree Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -176,6 +176,7 @@ RUN pip install pysal && \
176
176
pip install tensorflow-gcs-config==${TENSORFLOW_VERSION} && \
177
177
# TODO(b/207851560) Upgrade to 0.17.1 once the base image with TensorFlow 2.9.1 is out.
178
178
pip install tensorflow-addons==0.17.0 && \
179
+ pip install tensorflow_decision_forests==0.2.0 && \
179
180
/tmp/clean-layer.sh
180
181
181
182
RUN apt-get install -y libfreetype6-dev && \
Original file line number Diff line number Diff line change
1
+ import unittest
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import tensorflow_decision_forests as tfdf
6
+
7
+ class TestTensorflowDecisionForest (unittest .TestCase ):
8
+ def test_fit (self ):
9
+ train_df = pd .read_csv ("/input/tests/data/train.csv" )
10
+
11
+ # Convert the dataset into a TensorFlow dataset.
12
+ train_ds = tfdf .keras .pd_dataframe_to_tf_dataset (train_df , label = "label" )
13
+
14
+ # Train the model
15
+ model = tfdf .keras .RandomForestModel (num_trees = 1 )
16
+ model .fit (train_ds )
17
+
18
+ self .assertEqual (1 , model .count_params ())
You can’t perform that action at this time.
0 commit comments