Skip to content

Commit 437168e

Browse files
authored
[Android] added tests for Tensor.java
Differential Revision: D65608097 Pull Request resolved: #6683
1 parent f9698d8 commit 437168e

File tree

1 file changed

+270
-0
lines changed

1 file changed

+270
-0
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.fail;
16+
17+
import java.nio.ByteBuffer;
18+
import java.nio.DoubleBuffer;
19+
import java.nio.FloatBuffer;
20+
import java.nio.IntBuffer;
21+
import java.nio.LongBuffer;
22+
23+
import org.junit.Test;
24+
import org.junit.runner.RunWith;
25+
import org.junit.runners.JUnit4;
26+
import org.pytorch.executorch.Tensor;
27+
28+
/** Unit tests for {@link Tensor}. */
29+
@RunWith(JUnit4.class)
30+
public class TensorTest {
31+
32+
@Test
33+
public void testFloatTensor() {
34+
float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE};
35+
long shape[] = {2, 2};
36+
Tensor tensor = Tensor.fromBlob(data, shape);
37+
assertEquals(tensor.dtype(), DType.FLOAT);
38+
assertEquals(shape[0], tensor.shape()[0]);
39+
assertEquals(shape[1], tensor.shape()[1]);
40+
assertEquals(4, tensor.numel());
41+
assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5);
42+
assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5);
43+
assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5);
44+
assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5);
45+
46+
FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(4);
47+
floatBuffer.put(data);
48+
tensor = Tensor.fromBlob(floatBuffer, shape);
49+
assertEquals(tensor.dtype(), DType.FLOAT);
50+
assertEquals(shape[0], tensor.shape()[0]);
51+
assertEquals(shape[1], tensor.shape()[1]);
52+
assertEquals(4, tensor.numel());
53+
assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5);
54+
assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5);
55+
assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5);
56+
assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5);
57+
}
58+
59+
@Test
60+
public void testIntTensor() {
61+
int data[] = {Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE};
62+
long shape[] = {1, 4, 1};
63+
Tensor tensor = Tensor.fromBlob(data, shape);
64+
assertEquals(tensor.dtype(), DType.INT32);
65+
assertEquals(shape[0], tensor.shape()[0]);
66+
assertEquals(shape[1], tensor.shape()[1]);
67+
assertEquals(shape[2], tensor.shape()[2]);
68+
assertEquals(4, tensor.numel());
69+
assertEquals(data[0], tensor.getDataAsIntArray()[0]);
70+
assertEquals(data[1], tensor.getDataAsIntArray()[1]);
71+
assertEquals(data[2], tensor.getDataAsIntArray()[2]);
72+
assertEquals(data[3], tensor.getDataAsIntArray()[3]);
73+
74+
IntBuffer intBuffer = Tensor.allocateIntBuffer(4);
75+
intBuffer.put(data);
76+
tensor = Tensor.fromBlob(intBuffer, shape);
77+
assertEquals(tensor.dtype(), DType.INT32);
78+
assertEquals(shape[0], tensor.shape()[0]);
79+
assertEquals(shape[1], tensor.shape()[1]);
80+
assertEquals(shape[2], tensor.shape()[2]);
81+
assertEquals(4, tensor.numel());
82+
assertEquals(data[0], tensor.getDataAsIntArray()[0]);
83+
assertEquals(data[1], tensor.getDataAsIntArray()[1]);
84+
assertEquals(data[2], tensor.getDataAsIntArray()[2]);
85+
assertEquals(data[3], tensor.getDataAsIntArray()[3]);
86+
}
87+
88+
@Test
89+
public void testDoubleTensor() {
90+
double data[] = {Double.MIN_VALUE, 0.0d, 0.1d, Double.MAX_VALUE};
91+
long shape[] = {1, 4};
92+
Tensor tensor = Tensor.fromBlob(data, shape);
93+
assertEquals(tensor.dtype(), DType.DOUBLE);
94+
assertEquals(shape[0], tensor.shape()[0]);
95+
assertEquals(shape[1], tensor.shape()[1]);
96+
assertEquals(4, tensor.numel());
97+
assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5);
98+
assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5);
99+
assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5);
100+
assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5);
101+
102+
DoubleBuffer doubleBuffer = Tensor.allocateDoubleBuffer(4);
103+
doubleBuffer.put(data);
104+
tensor = Tensor.fromBlob(doubleBuffer, shape);
105+
assertEquals(tensor.dtype(), DType.DOUBLE);
106+
assertEquals(shape[0], tensor.shape()[0]);
107+
assertEquals(shape[1], tensor.shape()[1]);
108+
assertEquals(4, tensor.numel());
109+
assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5);
110+
assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5);
111+
assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5);
112+
assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5);
113+
}
114+
115+
@Test
116+
public void testLongTensor() {
117+
long data[] = {Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE};
118+
long shape[] = {4, 1};
119+
Tensor tensor = Tensor.fromBlob(data, shape);
120+
assertEquals(tensor.dtype(), DType.INT64);
121+
assertEquals(shape[0], tensor.shape()[0]);
122+
assertEquals(shape[1], tensor.shape()[1]);
123+
assertEquals(4, tensor.numel());
124+
assertEquals(data[0], tensor.getDataAsLongArray()[0]);
125+
assertEquals(data[1], tensor.getDataAsLongArray()[1]);
126+
assertEquals(data[2], tensor.getDataAsLongArray()[2]);
127+
assertEquals(data[3], tensor.getDataAsLongArray()[3]);
128+
129+
LongBuffer longBuffer = Tensor.allocateLongBuffer(4);
130+
longBuffer.put(data);
131+
tensor = Tensor.fromBlob(longBuffer, shape);
132+
assertEquals(tensor.dtype(), DType.INT64);
133+
assertEquals(shape[0], tensor.shape()[0]);
134+
assertEquals(shape[1], tensor.shape()[1]);
135+
assertEquals(4, tensor.numel());
136+
assertEquals(data[0], tensor.getDataAsLongArray()[0]);
137+
assertEquals(data[1], tensor.getDataAsLongArray()[1]);
138+
assertEquals(data[2], tensor.getDataAsLongArray()[2]);
139+
assertEquals(data[3], tensor.getDataAsLongArray()[3]);
140+
}
141+
142+
@Test
143+
public void testSignedByteTensor() {
144+
byte data[] = {Byte.MIN_VALUE, (byte) 0, (byte) 1, Byte.MAX_VALUE};
145+
long shape[] = {1, 1, 4};
146+
Tensor tensor = Tensor.fromBlob(data, shape);
147+
assertEquals(tensor.dtype(), DType.INT8);
148+
assertEquals(shape[0], tensor.shape()[0]);
149+
assertEquals(shape[1], tensor.shape()[1]);
150+
assertEquals(shape[2], tensor.shape()[2]);
151+
assertEquals(4, tensor.numel());
152+
assertEquals(data[0], tensor.getDataAsByteArray()[0]);
153+
assertEquals(data[1], tensor.getDataAsByteArray()[1]);
154+
assertEquals(data[2], tensor.getDataAsByteArray()[2]);
155+
assertEquals(data[3], tensor.getDataAsByteArray()[3]);
156+
157+
ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4);
158+
byteBuffer.put(data);
159+
tensor = Tensor.fromBlob(byteBuffer, shape);
160+
assertEquals(tensor.dtype(), DType.INT8);
161+
assertEquals(shape[0], tensor.shape()[0]);
162+
assertEquals(shape[1], tensor.shape()[1]);
163+
assertEquals(shape[2], tensor.shape()[2]);
164+
assertEquals(4, tensor.numel());
165+
assertEquals(data[0], tensor.getDataAsByteArray()[0]);
166+
assertEquals(data[1], tensor.getDataAsByteArray()[1]);
167+
assertEquals(data[2], tensor.getDataAsByteArray()[2]);
168+
assertEquals(data[3], tensor.getDataAsByteArray()[3]);
169+
}
170+
171+
@Test
172+
public void testUnsignedByteTensor() {
173+
byte data[] = {(byte) 0, (byte) 1, (byte) 2, (byte) 255};
174+
long shape[] = {4, 1, 1};
175+
Tensor tensor = Tensor.fromBlobUnsigned(data, shape);
176+
assertEquals(tensor.dtype(), DType.UINT8);
177+
assertEquals(shape[0], tensor.shape()[0]);
178+
assertEquals(shape[1], tensor.shape()[1]);
179+
assertEquals(shape[2], tensor.shape()[2]);
180+
assertEquals(4, tensor.numel());
181+
assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]);
182+
assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]);
183+
assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]);
184+
assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]);
185+
186+
ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4);
187+
byteBuffer.put(data);
188+
tensor = Tensor.fromBlobUnsigned(byteBuffer, shape);
189+
assertEquals(tensor.dtype(), DType.UINT8);
190+
assertEquals(shape[0], tensor.shape()[0]);
191+
assertEquals(shape[1], tensor.shape()[1]);
192+
assertEquals(shape[2], tensor.shape()[2]);
193+
assertEquals(4, tensor.numel());
194+
assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]);
195+
assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]);
196+
assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]);
197+
assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]);
198+
}
199+
200+
@Test
201+
public void testIllegalDataTypeException() {
202+
float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE};
203+
long shape[] = {2, 2};
204+
Tensor tensor = Tensor.fromBlob(data, shape);
205+
assertEquals(tensor.dtype(), DType.FLOAT);
206+
207+
try {
208+
tensor.getDataAsByteArray();
209+
fail("Should have thrown an exception");
210+
} catch (IllegalStateException e) {
211+
// expected
212+
}
213+
try {
214+
tensor.getDataAsUnsignedByteArray();
215+
fail("Should have thrown an exception");
216+
} catch (IllegalStateException e) {
217+
// expected
218+
}
219+
try {
220+
tensor.getDataAsIntArray();
221+
fail("Should have thrown an exception");
222+
} catch (IllegalStateException e) {
223+
// expected
224+
}
225+
try {
226+
tensor.getDataAsDoubleArray();
227+
fail("Should have thrown an exception");
228+
} catch (IllegalStateException e) {
229+
// expected
230+
}
231+
try {
232+
tensor.getDataAsLongArray();
233+
fail("Should have thrown an exception");
234+
} catch (IllegalStateException e) {
235+
// expected
236+
}
237+
}
238+
239+
@Test
240+
public void testIllegalArguments() {
241+
float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE};
242+
long shapeWithNegativeValues[] = {-1, 2};
243+
long mismatchShape[] = {1, 2};
244+
245+
try {
246+
Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape);
247+
fail("Should have thrown an exception");
248+
} catch (IllegalArgumentException e) {
249+
// expected
250+
}
251+
try {
252+
Tensor tensor = Tensor.fromBlob(data, null);
253+
fail("Should have thrown an exception");
254+
} catch (IllegalArgumentException e) {
255+
// expected
256+
}
257+
try {
258+
Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues);
259+
fail("Should have thrown an exception");
260+
} catch (IllegalArgumentException e) {
261+
// expected
262+
}
263+
try {
264+
Tensor tensor = Tensor.fromBlob(data, mismatchShape);
265+
fail("Should have thrown an exception");
266+
} catch (IllegalArgumentException e) {
267+
// expected
268+
}
269+
}
270+
}

0 commit comments

Comments
 (0)