Skip to content

Commit a1efcff

Browse files
authored
Keras CV (#1270)
http://b/270144038
1 parent a581e74 commit a1efcff

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

Dockerfile.tmpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ RUN pip install flashtext \
526526
accelerate==0.12.0 \
527527
catalyst \
528528
# b/206990323 osmx 1.1.2 requires numpy >= 1.21 which we don't want.
529-
osmnx==1.1.1 && \
529+
osmnx==1.1.1 \
530+
keras-cv && \
530531
apt-get -y install libspatialindex-dev
531532
RUN pip install pytorch-ignite \
532533
qgrid \

tests/test_kerascv.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
3+
import json
4+
import keras_cv
5+
import keras
6+
import numpy as np
7+
8+
class TestKerasCV(unittest.TestCase):
9+
def test_inference(self):
10+
classifier = keras_cv.models.ImageClassifier.from_preset(
11+
'efficientnetv2_b0_imagenet_classifier',
12+
load_weights=False, # load randomly initialized model from preset architecture with weights
13+
)
14+
image = keras.utils.load_img('/input/tests/data/face.jpg')
15+
image = np.array(image)
16+
keras_cv.visualization.plot_image_gallery(
17+
[image], rows=1, cols=1, value_range=(0, 255), show=True, scale=4
18+
)
19+
predictions = classifier.predict(np.expand_dims(image, axis=0))
20+
top_classes = predictions[0].argsort(axis=-1)
21+
self.assertEqual(1000, len(top_classes))

0 commit comments

Comments
 (0)