Skip to content

Java Tensor and EValue serialization #6620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;
import org.pytorch.executorch.annotations.Experimental;
Expand Down Expand Up @@ -287,4 +289,75 @@ private void preconditionType(int typeCodeExpected, int typeCode) {
/** Maps a type code to its display name, falling back to "Unknown" for out-of-range codes. */
private String getTypeName(int typeCode) {
  // Guard first: any code outside the TYPE_NAMES index range has no name.
  if (typeCode < 0 || typeCode >= TYPE_NAMES.length) {
    return "Unknown";
  }
  return TYPE_NAMES[typeCode];
}

/**
* Serializes an {@code EValue} into a byte array.
*
* @return The serialized byte array.
* @apiNote This method is experimental and subject to change without notice. This does NOT
* supoprt list type.
*/
public byte[] toByteArray() {
if (isNone()) {
return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array();
} else if (isTensor()) {
Tensor t = toTensor();
byte[] tByteArray = t.toByteArray();
return ByteBuffer.allocate(1 + tByteArray.length)
.put((byte) TYPE_CODE_TENSOR)
.put(tByteArray)
.array();
} else if (isBool()) {
return ByteBuffer.allocate(2)
.put((byte) TYPE_CODE_BOOL)
.put((byte) (toBool() ? 1 : 0))
.array();
} else if (isInt()) {
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array();
} else if (isDouble()) {
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
} else if (isString()) {
return ByteBuffer.allocate(1 + toString().length())
.put((byte) TYPE_CODE_STRING)
.put(toString().getBytes())
.array();
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will we add list support as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per Jacob, list is an internal dtype within ET runtime. Maybe we should totally get rid of list in java layer. I could double check with the team

throw new IllegalArgumentException("Unknown Tensor dtype");
}
}

/**
 * Deserializes an {@code EValue} from a byte[].
 *
 * <p>The first byte is read as the type code; the rest is the payload for that type. Tensor
 * payloads are delegated to {@link Tensor#fromByteArray(byte[])}.
 *
 * @param bytes The byte array to deserialize from.
 * @return The deserialized {@code EValue}.
 * @throws IllegalArgumentException if {@code bytes} is null, empty, too short for the encoded
 *     type, carries an unknown type code, or encodes a string (not supported).
 * @apiNote This method is experimental and subject to change without notice. This does NOT
 *     support list type.
 */
public static EValue fromByteArray(byte[] bytes) {
  // Check the argument itself: ByteBuffer.wrap never returns null and would throw a
  // NullPointerException on a null array before any later check could run.
  if (bytes == null) {
    throw new IllegalArgumentException("bytes cannot be null");
  }
  ByteBuffer buffer = ByteBuffer.wrap(bytes);
  if (!buffer.hasRemaining()) {
    throw new IllegalArgumentException("invalid buffer");
  }
  int typeCode = buffer.get();
  switch (typeCode) {
    case TYPE_CODE_NONE:
      return new EValue(TYPE_CODE_NONE);
    case TYPE_CODE_TENSOR:
      // Everything after the type-code byte is the serialized tensor.
      return from(Tensor.fromByteArray(Arrays.copyOfRange(bytes, 1, bytes.length)));
    case TYPE_CODE_STRING:
      throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
    case TYPE_CODE_DOUBLE:
      requireRemaining(buffer, 8); // a double occupies 8 bytes
      return from(buffer.getDouble());
    case TYPE_CODE_INT:
      requireRemaining(buffer, 8); // ints are written as 8-byte longs
      return from(buffer.getLong());
    case TYPE_CODE_BOOL:
      requireRemaining(buffer, 1);
      return from(buffer.get() != 0);
    default:
      throw new IllegalArgumentException("invalid type code: " + typeCode);
  }
}

/** Rejects truncated input with the same exception type used for other malformed buffers. */
private static void requireRemaining(ByteBuffer buffer, int byteCount) {
  if (buffer.remaining() < byteCount) {
    throw new IllegalArgumentException("invalid buffer");
  }
}
}
104 changes: 103 additions & 1 deletion extension/android/src/main/java/org/pytorch/executorch/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ public abstract class Tensor {

@DoNotStrip final long[] shape;

// Size in bytes of one element of each primitive type used for tensor data.
private static final int BYTE_SIZE_BYTES = 1;
private static final int INT_SIZE_BYTES = 4;
private static final int LONG_SIZE_BYTES = 8;
private static final int FLOAT_SIZE_BYTES = 4;
private static final int DOUBLE_SIZE_BYTES = 8;

/**
Expand Down Expand Up @@ -679,4 +680,105 @@ private static Tensor nativeNewTensor(
tensor.mHybridData = hybridData;
return tensor;
}

/**
 * Serializes a {@code Tensor} into a byte array.
 *
 * <p>Layout: one byte holding the dtype's {@code jniCode}, one byte holding the number of
 * dimensions, each dimension as a 4-byte int, then the element data. Multi-byte values use the
 * default {@link ByteBuffer} byte order (big-endian).
 *
 * @return The serialized byte array.
 * @throws IllegalArgumentException if the dtype is INT16 or unknown, or if the rank, a dimension
 *     or the total size does not fit the serialized format.
 * @apiNote This method is experimental and subject to change without notice.
 */
public byte[] toByteArray() {
  final DType dtype = dtype();
  final int dtypeSize;
  if (dtype == DType.UINT8 || dtype == DType.INT8) {
    dtypeSize = BYTE_SIZE_BYTES;
  } else if (dtype == DType.INT16) {
    throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
  } else if (dtype == DType.INT32) {
    dtypeSize = INT_SIZE_BYTES;
  } else if (dtype == DType.INT64) {
    dtypeSize = LONG_SIZE_BYTES;
  } else if (dtype == DType.FLOAT) {
    dtypeSize = FLOAT_SIZE_BYTES;
  } else if (dtype == DType.DOUBLE) {
    dtypeSize = DOUBLE_SIZE_BYTES;
  } else {
    throw new IllegalArgumentException("Unknown Tensor dtype");
  }

  // The rank is stored in a single signed byte; refuse values that would be truncated.
  if (shape.length > Byte.MAX_VALUE) {
    throw new IllegalArgumentException(
        "Tensor rank " + shape.length + " does not fit in the serialized header");
  }
  // Do the size arithmetic in long so an oversized tensor is rejected instead of silently
  // wrapping around when narrowed to int.
  final long elementCount = numel();
  final long headerBytes = 1L + 1L + (long) INT_SIZE_BYTES * shape.length;
  if (elementCount < 0 || elementCount > (Integer.MAX_VALUE - headerBytes) / dtypeSize) {
    throw new IllegalArgumentException("Tensor is too large to serialize into a byte array");
  }

  ByteBuffer byteBuffer = ByteBuffer.allocate((int) (headerBytes + elementCount * dtypeSize));
  byteBuffer.put((byte) dtype.jniCode);
  byteBuffer.put((byte) shape.length);
  for (long s : shape) {
    // Each dimension is stored as a 4-byte int.
    if (s < 0 || s > Integer.MAX_VALUE) {
      throw new IllegalArgumentException("Tensor dimension " + s + " cannot be serialized");
    }
    byteBuffer.putInt((int) s);
  }

  // Write the element data straight after the header. The typed views start at the buffer's
  // current position, so no intermediate copy is needed.
  if (dtype == DType.UINT8) {
    byteBuffer.put(((Tensor_uint8) this).getDataAsUnsignedByteArray());
  } else if (dtype == DType.INT8) {
    byteBuffer.put(((Tensor_int8) this).getDataAsByteArray());
  } else if (dtype == DType.INT32) {
    byteBuffer.asIntBuffer().put(((Tensor_int32) this).getDataAsIntArray());
  } else if (dtype == DType.INT64) {
    byteBuffer.asLongBuffer().put(((Tensor_int64) this).getDataAsLongArray());
  } else if (dtype == DType.FLOAT) {
    byteBuffer.asFloatBuffer().put(((Tensor_float32) this).getDataAsFloatArray());
  } else {
    // Only DOUBLE can reach here; every other dtype was handled or rejected above.
    byteBuffer.asDoubleBuffer().put(((Tensor_float64) this).getDataAsDoubleArray());
  }
  return byteBuffer.array();
}

/**
 * Deserializes a {@code Tensor} from a byte[].
 *
 * <p>Expects one dtype-code byte, one byte holding the number of dimensions, each dimension as
 * a 4-byte int, then the element data.
 *
 * @param bytes The byte array to deserialize from.
 * @return The deserialized {@code Tensor}.
 * @throws IllegalArgumentException if {@code bytes} is null, the header or shape is malformed,
 *     the element data is shorter than the shape requires, or the dtype code is not recognized.
 * @apiNote This method is experimental and subject to change without notice.
 */
public static Tensor fromByteArray(byte[] bytes) {
  if (bytes == null) {
    throw new IllegalArgumentException("bytes cannot be null");
  }
  ByteBuffer buffer = ByteBuffer.wrap(bytes);
  // The header starts with two single-byte fields (dtype code and rank); both must be present.
  if (buffer.remaining() < 2) {
    throw new IllegalArgumentException("invalid buffer");
  }
  byte dtype = buffer.get();
  byte shapeLength = buffer.get();
  // The rank byte is signed: reject negative values and ranks whose dims are not all present.
  if (shapeLength < 0 || buffer.remaining() < INT_SIZE_BYTES * shapeLength) {
    throw new IllegalArgumentException("invalid shape");
  }
  long[] shape = new long[shapeLength];
  long numel = 1;
  for (int i = 0; i < shapeLength; i++) {
    int dim = buffer.getInt();
    // Besides negative dims, also reject shapes whose element count overflows a long.
    if (dim < 0 || (dim != 0 && numel > Long.MAX_VALUE / dim)) {
      throw new IllegalArgumentException("invalid shape");
    }
    shape[i] = dim;
    numel *= dim;
  }
  if (dtype == DType.UINT8.jniCode) {
    return new Tensor_uint8(requirePayload(buffer, numel, BYTE_SIZE_BYTES), shape);
  } else if (dtype == DType.INT8.jniCode) {
    return new Tensor_int8(requirePayload(buffer, numel, BYTE_SIZE_BYTES), shape);
  } else if (dtype == DType.INT32.jniCode) {
    return new Tensor_int32(requirePayload(buffer, numel, INT_SIZE_BYTES).asIntBuffer(), shape);
  } else if (dtype == DType.INT64.jniCode) {
    return new Tensor_int64(
        requirePayload(buffer, numel, LONG_SIZE_BYTES).asLongBuffer(), shape);
  } else if (dtype == DType.FLOAT.jniCode) {
    return new Tensor_float32(
        requirePayload(buffer, numel, FLOAT_SIZE_BYTES).asFloatBuffer(), shape);
  } else if (dtype == DType.DOUBLE.jniCode) {
    return new Tensor_float64(
        requirePayload(buffer, numel, DOUBLE_SIZE_BYTES).asDoubleBuffer(), shape);
  } else {
    throw new IllegalArgumentException("Unknown Tensor dtype");
  }
}

/**
 * Returns {@code buffer} unchanged after checking that it still holds at least {@code numel}
 * elements of {@code elementSize} bytes each.
 */
private static ByteBuffer requirePayload(ByteBuffer buffer, long numel, int elementSize) {
  // Compare against a quotient rather than multiplying, so the check itself cannot overflow.
  if (numel > buffer.remaining() / elementSize) {
    throw new IllegalArgumentException("invalid buffer");
  }
  return buffer;
}
}
Loading
Loading