Skip to content

Commit 44283a2

Browse files
committed
Java Tensor and EValue serialization
Add serialization and deserialization for EValue (except string) and Tensor. RFC: #6569
1 parent f40daea commit 44283a2

File tree

4 files changed

+406
-65
lines changed

4 files changed

+406
-65
lines changed

extension/android/src/main/java/org/pytorch/executorch/EValue.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
package org.pytorch.executorch;
1010

1111
import com.facebook.jni.annotations.DoNotStrip;
12+
import java.nio.ByteBuffer;
13+
import java.util.Arrays;
1214
import java.util.Locale;
1315
import java.util.Optional;
1416
import org.pytorch.executorch.annotations.Experimental;
@@ -287,4 +289,75 @@ private void preconditionType(int typeCodeExpected, int typeCode) {
287289
private String getTypeName(int typeCode) {
288290
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
289291
}
292+
293+
/**
294+
* Serializes an {@code EValue} into a byte array.
295+
*
296+
* @return The serialized byte array.
297+
* @apiNote This method is experimental and subject to change without notice. This does NOT
298+
* supoprt list type.
299+
*/
300+
public byte[] toByteArray() {
301+
if (isNone()) {
302+
return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array();
303+
} else if (isTensor()) {
304+
Tensor t = toTensor();
305+
byte[] tByteArray = t.toByteArray();
306+
return ByteBuffer.allocate(1 + tByteArray.length)
307+
.put((byte) TYPE_CODE_TENSOR)
308+
.put(tByteArray)
309+
.array();
310+
} else if (isBool()) {
311+
return ByteBuffer.allocate(2)
312+
.put((byte) TYPE_CODE_BOOL)
313+
.put((byte) (toBool() ? 1 : 0))
314+
.array();
315+
} else if (isInt()) {
316+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array();
317+
} else if (isDouble()) {
318+
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
319+
} else if (isString()) {
320+
return ByteBuffer.allocate(1 + toString().length())
321+
.put((byte) TYPE_CODE_STRING)
322+
.put(toString().getBytes())
323+
.array();
324+
} else {
325+
throw new IllegalArgumentException("Unknown Tensor dtype");
326+
}
327+
}
328+
329+
/**
330+
* Deserializes an {@code EValue} from a byte[].
331+
*
332+
* @param bytes The byte array to deserialize from.
333+
* @return The deserialized {@code EValue}.
334+
* @apiNote This method is experimental and subject to change without notice. This does NOT list
335+
* type.
336+
*/
337+
public static EValue fromByteArray(byte[] bytes) {
338+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
339+
if (buffer == null) {
340+
throw new IllegalArgumentException("buffer cannot be null");
341+
}
342+
if (!buffer.hasRemaining()) {
343+
throw new IllegalArgumentException("invalid buffer");
344+
}
345+
int typeCode = buffer.get();
346+
switch (typeCode) {
347+
case TYPE_CODE_NONE:
348+
return new EValue(TYPE_CODE_NONE);
349+
case TYPE_CODE_TENSOR:
350+
byte[] bufferArray = buffer.array();
351+
return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
352+
case TYPE_CODE_STRING:
353+
throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
354+
case TYPE_CODE_DOUBLE:
355+
return from(buffer.getDouble());
356+
case TYPE_CODE_INT:
357+
return from(buffer.getLong());
358+
case TYPE_CODE_BOOL:
359+
return from(buffer.get() != 0);
360+
}
361+
throw new IllegalArgumentException("invalid type code: " + typeCode);
362+
}
290363
}

extension/android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,4 +679,105 @@ private static Tensor nativeNewTensor(
679679
tensor.mHybridData = hybridData;
680680
return tensor;
681681
}
682+
683+
/**
684+
* Serializes a {@code Tensor} into a byte array.
685+
*
686+
* @return The serialized byte array.
687+
* @apiNote This method is experimental and subject to change without notice. This does NOT
688+
* supoprt list type.
689+
*/
690+
public byte[] toByteArray() {
691+
int dtypeSize = 0;
692+
byte[] tensorAsByteArray = null;
693+
if (dtype() == DType.FLOAT) {
694+
dtypeSize = 4;
695+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
696+
Tensor_float32 thiz = (Tensor_float32) this;
697+
ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
698+
} else if (dtype() == DType.DOUBLE) {
699+
dtypeSize = 8;
700+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
701+
Tensor_float64 thiz = (Tensor_float64) this;
702+
ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
703+
} else if (dtype() == DType.UINT8) {
704+
dtypeSize = 1;
705+
tensorAsByteArray = new byte[(int) numel()];
706+
Tensor_uint8 thiz = (Tensor_uint8) this;
707+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
708+
} else if (dtype() == DType.INT8) {
709+
dtypeSize = 1;
710+
tensorAsByteArray = new byte[(int) numel()];
711+
Tensor_int8 thiz = (Tensor_int8) this;
712+
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
713+
} else if (dtype() == DType.INT16) {
714+
throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
715+
} else if (dtype() == DType.INT32) {
716+
dtypeSize = 4;
717+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
718+
Tensor_int32 thiz = (Tensor_int32) this;
719+
ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
720+
} else if (dtype() == DType.INT64) {
721+
dtypeSize = 8;
722+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
723+
Tensor_int64 thiz = (Tensor_int64) this;
724+
ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
725+
} else {
726+
throw new IllegalArgumentException("Unknown Tensor dtype");
727+
}
728+
ByteBuffer byteBuffer =
729+
ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
730+
byteBuffer.put((byte) dtype().jniCode);
731+
byteBuffer.put((byte) shape.length);
732+
for (long s : shape) {
733+
byteBuffer.putInt((int) s);
734+
}
735+
byteBuffer.put(tensorAsByteArray);
736+
return byteBuffer.array();
737+
}
738+
739+
/**
740+
* Deserializes a {@code Tensor} from a byte[].
741+
*
742+
* @param buffer The byte array to deserialize from.
743+
* @return The deserialized {@code Tensor}.
744+
* @apiNote This method is experimental and subject to change without notice. This does NOT
745+
* supoprt list type.
746+
*/
747+
public static Tensor fromByteArray(byte[] bytes) {
748+
if (bytes == null) {
749+
throw new IllegalArgumentException("bytes cannot be null");
750+
}
751+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
752+
if (!buffer.hasRemaining()) {
753+
throw new IllegalArgumentException("invalid buffer");
754+
}
755+
byte scalarType = buffer.get();
756+
byte numberOfDimensions = buffer.get();
757+
long[] shape = new long[(int) numberOfDimensions];
758+
long numel = 1;
759+
for (int i = 0; i < numberOfDimensions; i++) {
760+
int dim = buffer.getInt();
761+
if (dim < 0) {
762+
throw new IllegalArgumentException("invalid shape");
763+
}
764+
shape[i] = dim;
765+
numel *= dim;
766+
}
767+
if (scalarType == DType.FLOAT.jniCode) {
768+
return new Tensor_float32(buffer.asFloatBuffer(), shape);
769+
} else if (scalarType == DType.DOUBLE.jniCode) {
770+
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
771+
} else if (scalarType == DType.UINT8.jniCode) {
772+
return new Tensor_uint8(buffer, shape);
773+
} else if (scalarType == DType.INT8.jniCode) {
774+
return new Tensor_int8(buffer, shape);
775+
} else if (scalarType == DType.INT16.jniCode) {
776+
return new Tensor_int32(buffer.asIntBuffer(), shape);
777+
} else if (scalarType == DType.INT64.jniCode) {
778+
return new Tensor_int64(buffer.asLongBuffer(), shape);
779+
} else {
780+
throw new IllegalArgumentException("Unknown Tensor dtype");
781+
}
782+
}
682783
}

0 commit comments

Comments
 (0)