Skip to content

Commit 0598d89

Browse files
committed
Java Tensor and EValue serialization
WIP
1 parent f943856 commit 0598d89

File tree

4 files changed

+186
-0
lines changed

4 files changed

+186
-0
lines changed

extension/android/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ task makeJar(type: Jar) {
2222
implementation 'com.facebook.soloader:nativeloader:0.10.5'
2323
}
2424
}
25+
26+
dependencies {
27+
testImplementation "junit:junit:4.12"
28+
}

extension/android/src/main/java/org/pytorch/executorch/EValue.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
package org.pytorch.executorch;
1010

1111
import com.facebook.jni.annotations.DoNotStrip;
12+
import java.nio.ByteBuffer;
13+
import java.nio.charset.StandardCharsets;
1214
import java.util.Locale;
1315
import java.util.Optional;
1416
import org.pytorch.executorch.annotations.Experimental;
@@ -287,4 +289,64 @@ private void preconditionType(int typeCodeExpected, int typeCode) {
287289
private String getTypeName(int typeCode) {
288290
return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
289291
}
292+
293+
/**
294+
* Serializes an {@code EValue} into a {@link ByteBuffer}.
295+
* @return The serialized {@code ByteBuffer}.
296+
*
297+
* @apiNote This method is experimental and subject to change without notice.
298+
* This does NOT supoprt list type.
299+
*/
300+
public ByteBuffer toByteBuffer() {
301+
if (isNone()) {
302+
return ByteBuffer.allocateDirect(1).put((byte) TYPE_CODE_NONE);
303+
} else if (isTensor()) {
304+
Tensor t = toTensor();
305+
ByteBuffer tByteBuffer = t.toByteBuffer();
306+
return ByteBuffer.allocateDirect(1 + tByteBuffer.array().length).put((byte) TYPE_CODE_TENSOR).put(tByteBuffer);
307+
} else if (isBool()) {
308+
return ByteBuffer.allocateDirect(2).put((byte) TYPE_CODE_BOOL).put((byte) (toBool() ? 1 : 0));
309+
} else if (isInt()) {
310+
return ByteBuffer.allocateDirect(9).put((byte) TYPE_CODE_INT).putLong(toInt());
311+
} else if (isDouble()) {
312+
return ByteBuffer.allocateDirect(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble());
313+
} else if (isString()) {
314+
return ByteBuffer.allocateDirect(1 + toString().length()).put((byte) TYPE_CODE_STRING).put(toString().getBytes());
315+
} else {
316+
throw new IllegalArgumentException("Unknown Tensor dtype");
317+
}
318+
}
319+
320+
/**
321+
* Deserializes an {@code EValue} from a {@link ByteBuffer}.
322+
* @param buffer The {@link ByteBuffer} to deserialize from.
323+
* @return The deserialized {@code EValue}.
324+
*
325+
* @apiNote This method is experimental and subject to change without notice.
326+
* This does NOT list type.
327+
*/
328+
public static EValue fromByteBuffer(ByteBuffer buffer) {
329+
if (buffer == null) {
330+
throw new IllegalArgumentException("buffer cannot be null");
331+
}
332+
if (!buffer.hasRemaining()) {
333+
throw new IllegalArgumentException("invalid buffer");
334+
}
335+
int typeCode = buffer.get();
336+
switch (typeCode) {
337+
case TYPE_CODE_NONE:
338+
return new EValue(TYPE_CODE_NONE);
339+
case TYPE_CODE_TENSOR:
340+
return from(Tensor.fromByteBuffer(buffer));
341+
case TYPE_CODE_STRING:
342+
return from(StandardCharsets.UTF_8.decode(buffer).toString());
343+
case TYPE_CODE_DOUBLE:
344+
return from(buffer.getDouble());
345+
case TYPE_CODE_INT:
346+
return from(buffer.getLong());
347+
case TYPE_CODE_BOOL:
348+
return from(buffer.get() != 0);
349+
}
350+
throw new IllegalArgumentException("invalid type code: " + typeCode);
351+
}
290352
}

extension/android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import com.facebook.jni.HybridData;
1212
import com.facebook.jni.annotations.DoNotStrip;
13+
import java.io.ByteArrayOutputStream;
14+
import java.io.ObjectOutputStream;
1315
import java.nio.Buffer;
1416
import java.nio.ByteBuffer;
1517
import java.nio.ByteOrder;
@@ -679,4 +681,99 @@ private static Tensor nativeNewTensor(
679681
tensor.mHybridData = hybridData;
680682
return tensor;
681683
}
684+
685+
private static final byte[] intToByteArray(int value) {
686+
return new byte[] {
687+
(byte)(value >>> 24),
688+
(byte)(value >>> 16),
689+
(byte)(value >>> 8),
690+
(byte)value};
691+
}
692+
693+
/**
694+
* Serializes a {@code Tensor} into a {@link ByteBuffer}.
695+
* @return The serialized {@code ByteBuffer}.
696+
*
697+
* @apiNote This method is experimental and subject to change without notice.
698+
* This does NOT supoprt list type.
699+
*/
700+
public ByteBuffer toByteBuffer() {
701+
int dtypeSize = 0;
702+
if (dtype() == DType.FLOAT) {
703+
dtypeSize = 4;
704+
} else if (dtype() == DType.DOUBLE) {
705+
dtypeSize = 8;
706+
} else if (dtype() == DType.UINT8) {
707+
dtypeSize = 1;
708+
} else if (dtype() == DType.INT8) {
709+
dtypeSize = 1;
710+
} else if (dtype() == DType.INT16) {
711+
dtypeSize = 2;
712+
} else if (dtype() == DType.INT32) {
713+
dtypeSize = 4;
714+
} else if (dtype() == DType.INT64) {
715+
dtypeSize = 8;
716+
} else {
717+
throw new IllegalArgumentException("Unknown Tensor dtype");
718+
}
719+
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
720+
byteBuffer.put((byte) dtype().jniCode);
721+
byteBuffer.put((byte) shape.length);
722+
for (long s : shape) {
723+
byteBuffer.put(intToByteArray((int) s));
724+
}
725+
ByteArrayOutputStream bos = new ByteArrayOutputStream();
726+
try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
727+
out.writeObject(getRawDataBuffer());
728+
out.flush();
729+
byteBuffer.put(bos.toByteArray());
730+
} catch (Exception ex) {
731+
throw new RuntimeException(ex);
732+
}
733+
return byteBuffer;
734+
}
735+
736+
/**
737+
* Deserializes a {@code Tensor} from a {@link ByteBuffer}.
738+
* @param buffer The {@link ByteBuffer} to deserialize from.
739+
* @return The deserialized {@code Tensor}.
740+
*
741+
* @apiNote This method is experimental and subject to change without notice.
742+
* This does NOT supoprt list type.
743+
*/
744+
public static Tensor fromByteBuffer(ByteBuffer buffer) {
745+
if (buffer == null) {
746+
throw new IllegalArgumentException("buffer cannot be null");
747+
}
748+
if (!buffer.hasRemaining()) {
749+
throw new IllegalArgumentException("invalid buffer");
750+
}
751+
byte scalarType = buffer.get();
752+
byte numberOfDimensions = buffer.get();
753+
long[] shape = new long[(int) numberOfDimensions];
754+
long numel = 1;
755+
for (int i = 0; i < numberOfDimensions; i++) {
756+
int dim = buffer.getInt();
757+
if (dim < 0) {
758+
throw new IllegalArgumentException("invalid shape");
759+
}
760+
shape[i] = dim;
761+
numel *= dim;
762+
}
763+
if (scalarType == DType.FLOAT.jniCode) {
764+
return new Tensor_float32(buffer.asFloatBuffer(), shape);
765+
} else if (scalarType == DType.DOUBLE.jniCode) {
766+
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
767+
} else if (scalarType == DType.UINT8.jniCode) {
768+
return new Tensor_uint8(buffer, shape);
769+
} else if (scalarType == DType.INT8.jniCode) {
770+
return new Tensor_int8(buffer, shape);
771+
} else if (scalarType == DType.INT16.jniCode) {
772+
return new Tensor_int32(buffer.asIntBuffer(), shape);
773+
} else if (scalarType == DType.INT64.jniCode) {
774+
return new Tensor_int64(buffer.asLongBuffer(), shape);
775+
} else {
776+
throw new IllegalArgumentException("Unknown Tensor dtype");
777+
}
778+
}
682779
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package org.pytorch.executorch;
2+
3+
import java.nio.ByteBuffer;
4+
import org.junit.Assert;
5+
import org.junit.Test;
6+
import org.junit.runner.RunWith;
7+
8+
import static org.junit.Assert.assertEquals;
9+
public class TensorTest {
10+
@Test
11+
public void testHello() {
12+
EValue evalue = EValue.from(1);
13+
ByteBuffer bb = evalue.toByteBuffer();
14+
assertEquals(4, bb.get());
15+
assertEquals(0, bb.get());
16+
assertEquals(0, bb.get());
17+
assertEquals(0, bb.get());
18+
assertEquals(1, bb.get());
19+
20+
// Tensor tensor = Tensor.fromBlob(new float[] {1.0f}, new long[] {1});
21+
// tensor.toByteBuffer();
22+
}
23+
}

0 commit comments

Comments
 (0)