Skip to content

Commit 2e43fcc

Browse files
authored
Add Hugging Face datasets (#1152)
- Included a smoke test - Fix apt-key issue http://b/230657835
1 parent f1a3cfc commit 2e43fcc

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

Dockerfile.tmpl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,16 @@ RUN pip uninstall -y horovod && \
5151
/tmp/clean-layer.sh
5252
{{ end }}
5353

54+
{{ if eq .Accelerator "gpu" }}
55+
# b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
56+
RUN rm /etc/apt/sources.list.d/cuda.list && \
57+
apt-key del 7fa2af80 && \
58+
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
59+
{{ end }}
60+
5461
# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
5562
# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
5663
RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
57-
# b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
58-
rm /etc/apt/sources.list.d/cuda.list && \
59-
apt-key del 7fa2af80 && \
60-
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub && \
6164
apt-get update && \
6265
# Needed by lightGBM (GPU build)
6366
# https://lightgbm.readthedocs.io/en/latest/GPU-Tutorial.html#build-lightgbm
@@ -491,6 +494,7 @@ RUN pip install flashtext && \
491494
pip install bqplot && \
492495
pip install earthengine-api && \
493496
pip install transformers && \
497+
pip install datasets && \
494498
pip install dlib && \
495499
pip install kaggle-environments && \
496500
pip install geopandas && \

tests/test_hf_datasets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import unittest
2+
3+
from datasets import Dataset
4+
5+
6+
class TestHuggingFaceDatasets(unittest.TestCase):
7+
8+
def test_map(self):
9+
def some_func(batch):
10+
batch['label'] = 'foo'
11+
return batch
12+
13+
df = Dataset.from_dict({'text': ['Kaggle rocks!']})
14+
mapped_df = df.map(some_func)
15+
16+
self.assertEqual('foo', mapped_df[0]['label'])

0 commit comments

Comments
 (0)