Skip to content

Commit 8ab3385

Browse files
authored
[Java][Android] add unit test for EValue (#6641)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent f31bcca commit 8ab3385

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed

extension/android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ task makeJar(type: Jar) {
2020
dependencies {
2121
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
2222
implementation 'com.facebook.soloader:nativeloader:0.10.5'
23+
testImplementation 'junit:junit:4.13.2'
2324
}
2425
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.fail;
16+
17+
import com.facebook.jni.annotations.DoNotStrip;
18+
19+
import java.util.List;
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
22+
import java.util.Locale;
23+
import java.util.Optional;
24+
25+
import org.pytorch.executorch.Tensor.Tensor_int64;
26+
import org.pytorch.executorch.annotations.Experimental;
27+
28+
import org.junit.Test;
29+
import org.junit.runner.RunWith;
30+
import org.junit.runners.JUnit4;
31+
32+
/** Unit tests for {@link EValue}. */
33+
@RunWith(JUnit4.class)
34+
public class EValueTest {
35+
36+
@Test
37+
public void testNone() {
38+
EValue evalue = EValue.optionalNone();
39+
assertTrue(evalue.isNone());
40+
}
41+
42+
@Test
43+
public void testTensorValue() {
44+
long[] data = {1, 2, 3};
45+
long[] shape = {1, 3};
46+
EValue evalue = EValue.from(Tensor.fromBlob(data, shape));
47+
assertTrue(evalue.isTensor());
48+
assertTrue(Arrays.equals(evalue.toTensor().shape, shape));
49+
assertTrue(Arrays.equals(evalue.toTensor().getDataAsLongArray(), data));
50+
}
51+
52+
@Test
53+
public void testBoolValue() {
54+
EValue evalue = EValue.from(true);
55+
assertTrue(evalue.isBool());
56+
assertTrue(evalue.toBool());
57+
}
58+
59+
@Test
60+
public void testIntValue() {
61+
EValue evalue = EValue.from(1);
62+
assertTrue(evalue.isInt());
63+
assertEquals(evalue.toInt(), 1);
64+
}
65+
66+
@Test
67+
public void testDoubleValue() {
68+
EValue evalue = EValue.from(0.1d);
69+
assertTrue(evalue.isDouble());
70+
assertEquals(evalue.toDouble(), 0.1d, 0.0001d);
71+
}
72+
73+
@Test
74+
public void testStringValue() {
75+
EValue evalue = EValue.from("a");
76+
assertTrue(evalue.isString());
77+
assertEquals(evalue.toStr(), "a");
78+
}
79+
80+
@Test
81+
public void testBoolListValue() {
82+
boolean[] value = {true, false, true};
83+
EValue evalue = EValue.listFrom(value);
84+
assertTrue(evalue.isBoolList());
85+
assertTrue(Arrays.equals(value, evalue.toBoolList()));
86+
}
87+
88+
@Test
89+
public void testIntListValue() {
90+
long[] value = {Long.MIN_VALUE, 0, Long.MAX_VALUE};
91+
EValue evalue = EValue.listFrom(value);
92+
assertTrue(evalue.isIntList());
93+
assertTrue(Arrays.equals(value, evalue.toIntList()));
94+
}
95+
96+
@Test
97+
public void testDoubleListValue() {
98+
double[] value = {Double.MIN_VALUE,0.1d, 0.01d, 0.001d, Double.MAX_VALUE};
99+
EValue evalue = EValue.listFrom(value);
100+
assertTrue(evalue.isDoubleList());
101+
assertTrue(Arrays.equals(value, evalue.toDoubleList()));
102+
}
103+
104+
@Test
105+
public void testTensorListValue() {
106+
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
107+
long[][] shape = {{1, 3}, {2, 3}};
108+
Tensor[] tensors = {Tensor.fromBlob(data[0], shape[0]), Tensor.fromBlob(data[1], shape[1])};
109+
110+
EValue evalue = EValue.listFrom(tensors);
111+
assertTrue(evalue.isTensorList());
112+
113+
assertTrue(Arrays.equals(evalue.toTensorList()[0].shape, shape[0]));
114+
assertTrue(Arrays.equals(evalue.toTensorList()[0].getDataAsLongArray(), data[0]));
115+
116+
assertTrue(Arrays.equals(evalue.toTensorList()[1].shape, shape[1]));
117+
assertTrue(Arrays.equals(evalue.toTensorList()[1].getDataAsLongArray(), data[1]));
118+
}
119+
120+
@Test
121+
@SuppressWarnings("unchecked")
122+
public void testOptionalTensorListValue() {
123+
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
124+
long[][] shape = {{1, 3}, {2, 3}};
125+
126+
EValue evalue = EValue.listFrom(
127+
Optional.<Tensor>empty(),
128+
Optional.of(Tensor.fromBlob(data[0], shape[0])),
129+
Optional.of(Tensor.fromBlob(data[1], shape[1])));
130+
assertTrue(evalue.isOptionalTensorList());
131+
132+
assertTrue(evalue.toOptionalTensorList()[0].isEmpty());
133+
134+
assertTrue(evalue.toOptionalTensorList()[1].isPresent());
135+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0]));
136+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().getDataAsLongArray(), data[0]));
137+
138+
assertTrue(evalue.toOptionalTensorList()[2].isPresent());
139+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().shape, shape[1]));
140+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().getDataAsLongArray(), data[1]));
141+
}
142+
143+
@Test
144+
public void testAllIllegalCast() {
145+
EValue evalue = EValue.optionalNone();
146+
assertTrue(evalue.isNone());
147+
148+
// try Tensor
149+
assertFalse(evalue.isTensor());
150+
try {
151+
evalue.toTensor();
152+
fail("Should have thrown an exception");
153+
} catch (IllegalStateException e) {}
154+
155+
// try bool
156+
assertFalse(evalue.isBool());
157+
try {
158+
evalue.toBool();
159+
fail("Should have thrown an exception");
160+
} catch (IllegalStateException e) {}
161+
162+
// try int
163+
assertFalse(evalue.isInt());
164+
try {
165+
evalue.toInt();
166+
fail("Should have thrown an exception");
167+
} catch (IllegalStateException e) {}
168+
169+
// try double
170+
assertFalse(evalue.isDouble());
171+
try {
172+
evalue.toDouble();
173+
fail("Should have thrown an exception");
174+
} catch (IllegalStateException e) {}
175+
176+
// try string
177+
assertFalse(evalue.isString());
178+
try {
179+
evalue.toStr();
180+
fail("Should have thrown an exception");
181+
} catch (IllegalStateException e) {}
182+
183+
// try bool list
184+
assertFalse(evalue.isBoolList());
185+
try {
186+
evalue.toBoolList();
187+
fail("Should have thrown an exception");
188+
} catch (IllegalStateException e) {}
189+
190+
// try int list
191+
assertFalse(evalue.isIntList());
192+
try {
193+
evalue.toIntList();
194+
fail("Should have thrown an exception");
195+
} catch (IllegalStateException e) {}
196+
197+
// try double list
198+
assertFalse(evalue.isDoubleList());
199+
try {
200+
evalue.toBool();
201+
fail("Should have thrown an exception");
202+
} catch (IllegalStateException e) {}
203+
204+
// try Tensor list
205+
assertFalse(evalue.isTensorList());
206+
try {
207+
evalue.toTensorList();
208+
fail("Should have thrown an exception");
209+
} catch (IllegalStateException e) {}
210+
211+
// try optional Tensor list
212+
assertFalse(evalue.isOptionalTensorList());
213+
try {
214+
evalue.toOptionalTensorList();
215+
fail("Should have thrown an exception");
216+
} catch (IllegalStateException e) {}
217+
}
218+
}

0 commit comments

Comments
 (0)