Skip to content

Commit c61a10b

Browse files
committed
Java Tensor and EValue serialization
Add serialization and deserialization for EValue (except string) and Tensor. RFC: #6569
1 parent f40daea commit c61a10b

File tree

4 files changed

+336
-0
lines changed

4 files changed

+336
-0
lines changed

extension/android/src/main/java/org/pytorch/executorch/EValue.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
package org.pytorch.executorch;
1010

1111
import com.facebook.jni.annotations.DoNotStrip;
12+
import java.nio.ByteBuffer;
13+
import java.util.Arrays;
1214
import java.util.Locale;
1315
import java.util.Optional;
1416
import org.pytorch.executorch.annotations.Experimental;
@@ -287,4 +289,66 @@ private void preconditionType(int typeCodeExpected, int typeCode) {
287289
private String getTypeName(int typeCode) {
288290
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
289291
}
292+
293+
/**
294+
* Serializes an {@code EValue} into a byte array.
295+
* @return The serialized byte array.
296+
*
297+
* @apiNote This method is experimental and subject to change without notice.
298+
* This does NOT supoprt list type.
299+
*/
300+
public byte[] toByteArray() {
301+
if (isNone()) {
302+
return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array();
303+
} else if (isTensor()) {
304+
Tensor t = toTensor();
305+
byte[] tByteArray = t.toByteArray();
306+
return ByteBuffer.allocate(1 + tByteArray.length).put((byte) TYPE_CODE_TENSOR).put(tByteArray).array();
307+
} else if (isBool()) {
308+
return ByteBuffer.allocate(2).put((byte) TYPE_CODE_BOOL).put((byte) (toBool() ? 1 : 0)).array();
309+
} else if (isInt()) {
310+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array();
311+
} else if (isDouble()) {
312+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
313+
} else if (isString()) {
314+
return ByteBuffer.allocate(1 + toString().length()).put((byte) TYPE_CODE_STRING).put(toString().getBytes()).array();
315+
} else {
316+
throw new IllegalArgumentException("Unknown Tensor dtype");
317+
}
318+
}
319+
320+
/**
321+
* Deserializes an {@code EValue} from a byte[].
322+
* @param bytes The byte array to deserialize from.
323+
* @return The deserialized {@code EValue}.
324+
*
325+
* @apiNote This method is experimental and subject to change without notice.
326+
* This does NOT list type.
327+
*/
328+
public static EValue fromByteArray(byte[] bytes) {
329+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
330+
if (buffer == null) {
331+
throw new IllegalArgumentException("buffer cannot be null");
332+
}
333+
if (!buffer.hasRemaining()) {
334+
throw new IllegalArgumentException("invalid buffer");
335+
}
336+
int typeCode = buffer.get();
337+
switch (typeCode) {
338+
case TYPE_CODE_NONE:
339+
return new EValue(TYPE_CODE_NONE);
340+
case TYPE_CODE_TENSOR:
341+
byte[] bufferArray = buffer.array();
342+
return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
343+
case TYPE_CODE_STRING:
344+
throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
345+
case TYPE_CODE_DOUBLE:
346+
return from(buffer.getDouble());
347+
case TYPE_CODE_INT:
348+
return from(buffer.getLong());
349+
case TYPE_CODE_BOOL:
350+
return from(buffer.get() != 0);
351+
}
352+
throw new IllegalArgumentException("invalid type code: " + typeCode);
353+
}
290354
}

extension/android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,4 +679,104 @@ private static Tensor nativeNewTensor(
679679
tensor.mHybridData = hybridData;
680680
return tensor;
681681
}
682+
683+
/**
684+
* Serializes a {@code Tensor} into a byte array.
685+
* @return The serialized byte array.
686+
*
687+
* @apiNote This method is experimental and subject to change without notice.
688+
* This does NOT supoprt list type.
689+
*/
690+
public byte[] toByteArray() {
691+
int dtypeSize = 0;
692+
byte[] tensorAsByteArray = null;
693+
if (dtype() == DType.FLOAT) {
694+
dtypeSize = 4;
695+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
696+
Tensor_float32 thiz = (Tensor_float32) this;
697+
ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
698+
} else if (dtype() == DType.DOUBLE) {
699+
dtypeSize = 8;
700+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
701+
Tensor_float64 thiz = (Tensor_float64) this;
702+
ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
703+
} else if (dtype() == DType.UINT8) {
704+
dtypeSize = 1;
705+
tensorAsByteArray = new byte[(int) numel()];
706+
Tensor_uint8 thiz = (Tensor_uint8) this;
707+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
708+
} else if (dtype() == DType.INT8) {
709+
dtypeSize = 1;
710+
tensorAsByteArray = new byte[(int) numel()];
711+
Tensor_int8 thiz = (Tensor_int8) this;
712+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
713+
} else if (dtype() == DType.INT16) {
714+
throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
715+
} else if (dtype() == DType.INT32) {
716+
dtypeSize = 4;
717+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
718+
Tensor_int32 thiz = (Tensor_int32) this;
719+
ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
720+
} else if (dtype() == DType.INT64) {
721+
dtypeSize = 8;
722+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
723+
Tensor_int64 thiz = (Tensor_int64) this;
724+
ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
725+
} else {
726+
throw new IllegalArgumentException("Unknown Tensor dtype");
727+
}
728+
ByteBuffer byteBuffer = ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
729+
byteBuffer.put((byte) dtype().jniCode);
730+
byteBuffer.put((byte) shape.length);
731+
for (long s : shape) {
732+
byteBuffer.putInt((int) s);
733+
}
734+
byteBuffer.put(tensorAsByteArray);
735+
return byteBuffer.array();
736+
}
737+
738+
/**
739+
* Deserializes a {@code Tensor} from a byte[].
740+
* @param buffer The byte array to deserialize from.
741+
* @return The deserialized {@code Tensor}.
742+
*
743+
* @apiNote This method is experimental and subject to change without notice.
744+
* This does NOT supoprt list type.
745+
*/
746+
public static Tensor fromByteArray(byte[] bytes) {
747+
if (bytes == null) {
748+
throw new IllegalArgumentException("bytes cannot be null");
749+
}
750+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
751+
if (!buffer.hasRemaining()) {
752+
throw new IllegalArgumentException("invalid buffer");
753+
}
754+
byte scalarType = buffer.get();
755+
byte numberOfDimensions = buffer.get();
756+
long[] shape = new long[(int) numberOfDimensions];
757+
long numel = 1;
758+
for (int i = 0; i < numberOfDimensions; i++) {
759+
int dim = buffer.getInt();
760+
if (dim < 0) {
761+
throw new IllegalArgumentException("invalid shape");
762+
}
763+
shape[i] = dim;
764+
numel *= dim;
765+
}
766+
if (scalarType == DType.FLOAT.jniCode) {
767+
return new Tensor_float32(buffer.asFloatBuffer(), shape);
768+
} else if (scalarType == DType.DOUBLE.jniCode) {
769+
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
770+
} else if (scalarType == DType.UINT8.jniCode) {
771+
return new Tensor_uint8(buffer, shape);
772+
} else if (scalarType == DType.INT8.jniCode) {
773+
return new Tensor_int8(buffer, shape);
774+
} else if (scalarType == DType.INT16.jniCode) {
775+
return new Tensor_int32(buffer.asIntBuffer(), shape);
776+
} else if (scalarType == DType.INT64.jniCode) {
777+
return new Tensor_int64(buffer.asLongBuffer(), shape);
778+
} else {
779+
throw new IllegalArgumentException("Unknown Tensor dtype");
780+
}
781+
}
682782
}

extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
/** Unit tests for {@link EValue}. */
3333
@RunWith(JUnit4.class)
3434
public class EValueTest {
35+
private static final int TYPE_CODE_NONE = 0;
36+
private static final int TYPE_CODE_TENSOR = 1;
37+
private static final int TYPE_CODE_STRING = 2;
38+
private static final int TYPE_CODE_DOUBLE = 3;
39+
private static final int TYPE_CODE_INT = 4;
40+
private static final int TYPE_CODE_BOOL = 5;
3541

3642
@Test
3743
public void testNone() {
@@ -215,4 +221,130 @@ public void testAllIllegalCast() {
215221
fail("Should have thrown an exception");
216222
} catch (IllegalStateException e) {}
217223
}
224+
225+
@Test
226+
public void testNoneSerde() {
227+
EValue evalue = EValue.optionalNone();
228+
byte[] bytes = evalue.toByteArray();
229+
assertEquals(TYPE_CODE_NONE, bytes[0]);
230+
231+
EValue deser = EValue.fromByteArray(bytes);
232+
assertEquals(deser.isNone(), true);
233+
}
234+
235+
@Test
236+
public void testBoolSerde() {
237+
EValue evalue = EValue.from(true);
238+
byte[] bytes = evalue.toByteArray();
239+
assertEquals(TYPE_CODE_BOOL, bytes[0]);
240+
assertEquals(1, bytes[1]);
241+
242+
EValue deser = EValue.fromByteArray(bytes);
243+
assertEquals(deser.isBool(), true);
244+
assertEquals(deser.toBool(), true);
245+
}
246+
247+
@Test
248+
public void testBoolSerde2() {
249+
EValue evalue = EValue.from(false);
250+
byte[] bytes = evalue.toByteArray();
251+
assertEquals(TYPE_CODE_BOOL, bytes[0]);
252+
assertEquals(0, bytes[1]);
253+
254+
EValue deser = EValue.fromByteArray(bytes);
255+
assertEquals(deser.isBool(), true);
256+
assertEquals(deser.toBool(), false);
257+
}
258+
259+
@Test
260+
public void testIntSerde() {
261+
EValue evalue = EValue.from(1);
262+
byte[] bytes = evalue.toByteArray();
263+
assertEquals(TYPE_CODE_INT, bytes[0]);
264+
assertEquals(0, bytes[1]);
265+
assertEquals(0, bytes[2]);
266+
assertEquals(0, bytes[3]);
267+
assertEquals(0, bytes[4]);
268+
assertEquals(0, bytes[5]);
269+
assertEquals(0, bytes[6]);
270+
assertEquals(0, bytes[7]);
271+
assertEquals(1, bytes[8]);
272+
273+
EValue deser = EValue.fromByteArray(bytes);
274+
assertEquals(deser.isInt(), true);
275+
assertEquals(deser.toInt(), 1);
276+
}
277+
278+
@Test
279+
public void testLargeIntSerde() {
280+
EValue evalue = EValue.from(256000);
281+
byte[] bytes = evalue.toByteArray();
282+
assertEquals(TYPE_CODE_INT, bytes[0]);
283+
284+
EValue deser = EValue.fromByteArray(bytes);
285+
assertEquals(deser.isInt(), true);
286+
assertEquals(deser.toInt(), 256000);
287+
}
288+
289+
@Test
290+
public void testDoubleSerde() {
291+
EValue evalue = EValue.from(1.345e-2d);
292+
byte[] bytes = evalue.toByteArray();
293+
assertEquals(TYPE_CODE_DOUBLE, bytes[0]);
294+
295+
EValue deser = EValue.fromByteArray(bytes);
296+
assertEquals(deser.isDouble(), true);
297+
assertEquals(1.345e-2d, deser.toDouble(), 1e-6);
298+
}
299+
300+
@Test
301+
public void testLongTensorSerde() {
302+
long data[] = {1, 2, 3, 4};
303+
long shape[] = {2, 2};
304+
Tensor tensor = Tensor.fromBlob(data, shape);
305+
306+
EValue evalue = EValue.from(tensor);
307+
byte[] bytes = evalue.toByteArray();
308+
assertEquals(TYPE_CODE_TENSOR, bytes[0]);
309+
310+
EValue deser = EValue.fromByteArray(bytes);
311+
assertEquals(deser.isTensor(), true);
312+
Tensor deserTensor = deser.toTensor();
313+
long[] deserShape = deserTensor.shape();
314+
long[] deserData = deserTensor.getDataAsLongArray();
315+
316+
for (int i = 0; i < data.length; i++) {
317+
assertEquals(data[i], deserData[i]);
318+
}
319+
320+
for (int i = 0; i < shape.length; i++) {
321+
assertEquals(shape[i], deserShape[i]);
322+
}
323+
}
324+
325+
@Test
326+
public void testFloatTensorSerde() {
327+
float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE};
328+
long shape[] = {2, 2};
329+
Tensor tensor = Tensor.fromBlob(data, shape);
330+
331+
EValue evalue = EValue.from(tensor);
332+
byte[] bytes = evalue.toByteArray();
333+
assertEquals(TYPE_CODE_TENSOR, bytes[0]);
334+
335+
EValue deser = EValue.fromByteArray(bytes);
336+
assertEquals(deser.isTensor(), true);
337+
Tensor deserTensor = deser.toTensor();
338+
long[] deserShape = deserTensor.shape();
339+
float[] deserData = deserTensor.getDataAsFloatArray();
340+
341+
for (int i = 0; i < data.length; i++) {
342+
assertEquals(data[i], deserData[i], 1e-5);
343+
}
344+
345+
for (int i = 0; i < shape.length; i++) {
346+
assertEquals(shape[i], deserShape[i]);
347+
}
348+
}
349+
218350
}

extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,4 +267,44 @@ public void testIllegalArguments() {
267267
// expected
268268
}
269269
}
270+
271+
@Test
272+
public void testLongTensorSerde() {
273+
long data[] = {1, 2, 3, 4};
274+
long shape[] = {2, 2};
275+
Tensor tensor = Tensor.fromBlob(data, shape);
276+
byte[] bytes = tensor.toByteArray();
277+
278+
Tensor deser = Tensor.fromByteArray(bytes);
279+
long[] deserShape = deser.shape();
280+
long[] deserData = deser.getDataAsLongArray();
281+
282+
for (int i = 0; i < data.length; i++) {
283+
assertEquals(data[i], deserData[i]);
284+
}
285+
286+
for (int i = 0; i < shape.length; i++) {
287+
assertEquals(shape[i], deserShape[i]);
288+
}
289+
}
290+
291+
@Test
292+
public void testFloatTensorSerde() {
293+
float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE};
294+
long shape[] = {2, 2};
295+
Tensor tensor = Tensor.fromBlob(data, shape);
296+
byte[] bytes = tensor.toByteArray();
297+
298+
Tensor deser = Tensor.fromByteArray(bytes);
299+
long[] deserShape = deser.shape();
300+
float[] deserData = deser.getDataAsFloatArray();
301+
302+
for (int i = 0; i < data.length; i++) {
303+
assertEquals(data[i], deserData[i], 1e-5);
304+
}
305+
306+
for (int i = 0; i < shape.length; i++) {
307+
assertEquals(shape[i], deserShape[i]);
308+
}
309+
}
270310
}

0 commit comments

Comments
 (0)