Skip to content

Commit 6ac19cc

Browse files
authored
Java Tensor and EValue serialization (#6620)
Add serialization and deserialization for EValue (except string) and Tensor. RFC: #6569
1 parent 7375cf5 commit 6ac19cc

File tree

4 files changed

+394
-67
lines changed

4 files changed

+394
-67
lines changed

extension/android/src/main/java/org/pytorch/executorch/EValue.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
package org.pytorch.executorch;
1010

1111
import com.facebook.jni.annotations.DoNotStrip;
12+
import java.nio.ByteBuffer;
13+
import java.util.Arrays;
1214
import java.util.Locale;
1315
import java.util.Optional;
1416
import org.pytorch.executorch.annotations.Experimental;
@@ -287,4 +289,75 @@ private void preconditionType(int typeCodeExpected, int typeCode) {
287289
private String getTypeName(int typeCode) {
288290
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
289291
}
292+
293+
/**
294+
* Serializes an {@code EValue} into a byte array.
295+
*
296+
* @return The serialized byte array.
297+
* @apiNote This method is experimental and subject to change without notice. This does NOT
298+
* supoprt list type.
299+
*/
300+
public byte[] toByteArray() {
301+
if (isNone()) {
302+
return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array();
303+
} else if (isTensor()) {
304+
Tensor t = toTensor();
305+
byte[] tByteArray = t.toByteArray();
306+
return ByteBuffer.allocate(1 + tByteArray.length)
307+
.put((byte) TYPE_CODE_TENSOR)
308+
.put(tByteArray)
309+
.array();
310+
} else if (isBool()) {
311+
return ByteBuffer.allocate(2)
312+
.put((byte) TYPE_CODE_BOOL)
313+
.put((byte) (toBool() ? 1 : 0))
314+
.array();
315+
} else if (isInt()) {
316+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array();
317+
} else if (isDouble()) {
318+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
319+
} else if (isString()) {
320+
return ByteBuffer.allocate(1 + toString().length())
321+
.put((byte) TYPE_CODE_STRING)
322+
.put(toString().getBytes())
323+
.array();
324+
} else {
325+
throw new IllegalArgumentException("Unknown Tensor dtype");
326+
}
327+
}
328+
329+
/**
330+
* Deserializes an {@code EValue} from a byte[].
331+
*
332+
* @param bytes The byte array to deserialize from.
333+
* @return The deserialized {@code EValue}.
334+
* @apiNote This method is experimental and subject to change without notice. This does NOT list
335+
* type.
336+
*/
337+
public static EValue fromByteArray(byte[] bytes) {
338+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
339+
if (buffer == null) {
340+
throw new IllegalArgumentException("buffer cannot be null");
341+
}
342+
if (!buffer.hasRemaining()) {
343+
throw new IllegalArgumentException("invalid buffer");
344+
}
345+
int typeCode = buffer.get();
346+
switch (typeCode) {
347+
case TYPE_CODE_NONE:
348+
return new EValue(TYPE_CODE_NONE);
349+
case TYPE_CODE_TENSOR:
350+
byte[] bufferArray = buffer.array();
351+
return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
352+
case TYPE_CODE_STRING:
353+
throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
354+
case TYPE_CODE_DOUBLE:
355+
return from(buffer.getDouble());
356+
case TYPE_CODE_INT:
357+
return from(buffer.getLong());
358+
case TYPE_CODE_BOOL:
359+
return from(buffer.get() != 0);
360+
}
361+
throw new IllegalArgumentException("invalid type code: " + typeCode);
362+
}
290363
}

extension/android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ public abstract class Tensor {
5353

5454
@DoNotStrip final long[] shape;
5555

56+
private static final int BYTE_SIZE_BYTES = 1;
5657
private static final int INT_SIZE_BYTES = 4;
57-
private static final int FLOAT_SIZE_BYTES = 4;
5858
private static final int LONG_SIZE_BYTES = 8;
59+
private static final int FLOAT_SIZE_BYTES = 4;
5960
private static final int DOUBLE_SIZE_BYTES = 8;
6061

6162
/**
@@ -679,4 +680,105 @@ private static Tensor nativeNewTensor(
679680
tensor.mHybridData = hybridData;
680681
return tensor;
681682
}
683+
684+
/**
685+
* Serializes a {@code Tensor} into a byte array.
686+
*
687+
* @return The serialized byte array.
688+
* @apiNote This method is experimental and subject to change without notice. This does NOT
689+
* supoprt list type.
690+
*/
691+
public byte[] toByteArray() {
692+
int dtypeSize = 0;
693+
byte[] tensorAsByteArray = null;
694+
if (dtype() == DType.UINT8) {
695+
dtypeSize = BYTE_SIZE_BYTES;
696+
tensorAsByteArray = new byte[(int) numel()];
697+
Tensor_uint8 thiz = (Tensor_uint8) this;
698+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
699+
} else if (dtype() == DType.INT8) {
700+
dtypeSize = BYTE_SIZE_BYTES;
701+
tensorAsByteArray = new byte[(int) numel()];
702+
Tensor_int8 thiz = (Tensor_int8) this;
703+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
704+
} else if (dtype() == DType.INT16) {
705+
throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
706+
} else if (dtype() == DType.INT32) {
707+
dtypeSize = INT_SIZE_BYTES;
708+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
709+
Tensor_int32 thiz = (Tensor_int32) this;
710+
ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
711+
} else if (dtype() == DType.INT64) {
712+
dtypeSize = LONG_SIZE_BYTES;
713+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
714+
Tensor_int64 thiz = (Tensor_int64) this;
715+
ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
716+
} else if (dtype() == DType.FLOAT) {
717+
dtypeSize = FLOAT_SIZE_BYTES;
718+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
719+
Tensor_float32 thiz = (Tensor_float32) this;
720+
ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
721+
} else if (dtype() == DType.DOUBLE) {
722+
dtypeSize = DOUBLE_SIZE_BYTES;
723+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
724+
Tensor_float64 thiz = (Tensor_float64) this;
725+
ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
726+
} else {
727+
throw new IllegalArgumentException("Unknown Tensor dtype");
728+
}
729+
ByteBuffer byteBuffer =
730+
ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
731+
byteBuffer.put((byte) dtype().jniCode);
732+
byteBuffer.put((byte) shape.length);
733+
for (long s : shape) {
734+
byteBuffer.putInt((int) s);
735+
}
736+
byteBuffer.put(tensorAsByteArray);
737+
return byteBuffer.array();
738+
}
739+
740+
/**
741+
* Deserializes a {@code Tensor} from a byte[].
742+
*
743+
* @param buffer The byte array to deserialize from.
744+
* @return The deserialized {@code Tensor}.
745+
* @apiNote This method is experimental and subject to change without notice. This does NOT
746+
* supoprt list type.
747+
*/
748+
public static Tensor fromByteArray(byte[] bytes) {
749+
if (bytes == null) {
750+
throw new IllegalArgumentException("bytes cannot be null");
751+
}
752+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
753+
if (!buffer.hasRemaining()) {
754+
throw new IllegalArgumentException("invalid buffer");
755+
}
756+
byte dtype = buffer.get();
757+
byte shapeLength = buffer.get();
758+
long[] shape = new long[(int) shapeLength];
759+
long numel = 1;
760+
for (int i = 0; i < shapeLength; i++) {
761+
int dim = buffer.getInt();
762+
if (dim < 0) {
763+
throw new IllegalArgumentException("invalid shape");
764+
}
765+
shape[i] = dim;
766+
numel *= dim;
767+
}
768+
if (dtype == DType.UINT8.jniCode) {
769+
return new Tensor_uint8(buffer, shape);
770+
} else if (dtype == DType.INT8.jniCode) {
771+
return new Tensor_int8(buffer, shape);
772+
} else if (dtype == DType.INT32.jniCode) {
773+
return new Tensor_int32(buffer.asIntBuffer(), shape);
774+
} else if (dtype == DType.INT64.jniCode) {
775+
return new Tensor_int64(buffer.asLongBuffer(), shape);
776+
} else if (dtype == DType.FLOAT.jniCode) {
777+
return new Tensor_float32(buffer.asFloatBuffer(), shape);
778+
} else if (dtype == DType.DOUBLE.jniCode) {
779+
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
780+
} else {
781+
throw new IllegalArgumentException("Unknown Tensor dtype");
782+
}
783+
}
682784
}

0 commit comments

Comments
 (0)