Skip to content

Commit 8750e17

Browse files
committed
Address comments
1 parent 44283a2 commit 8750e17

File tree

2 files changed

+29
-43
lines changed

2 files changed

+29
-43
lines changed

extension/android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ public abstract class Tensor {
5353

5454
@DoNotStrip final long[] shape;
5555

56+
private static final int BYTE_SIZE_BYTES = 1;
5657
private static final int INT_SIZE_BYTES = 4;
57-
private static final int FLOAT_SIZE_BYTES = 4;
5858
private static final int LONG_SIZE_BYTES = 8;
59+
private static final int FLOAT_SIZE_BYTES = 4;
5960
private static final int DOUBLE_SIZE_BYTES = 8;
6061

6162
/**
@@ -690,38 +691,38 @@ private static Tensor nativeNewTensor(
690691
public byte[] toByteArray() {
691692
int dtypeSize = 0;
692693
byte[] tensorAsByteArray = null;
693-
if (dtype() == DType.FLOAT) {
694-
dtypeSize = 4;
695-
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
696-
Tensor_float32 thiz = (Tensor_float32) this;
697-
ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
698-
} else if (dtype() == DType.DOUBLE) {
699-
dtypeSize = 8;
700-
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
701-
Tensor_float64 thiz = (Tensor_float64) this;
702-
ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
703-
} else if (dtype() == DType.UINT8) {
704-
dtypeSize = 1;
694+
if (dtype() == DType.UINT8) {
695+
dtypeSize = BYTE_SIZE_BYTES;
705696
tensorAsByteArray = new byte[(int) numel()];
706697
Tensor_uint8 thiz = (Tensor_uint8) this;
707698
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
708699
} else if (dtype() == DType.INT8) {
709-
dtypeSize = 1;
700+
dtypeSize = BYTE_SIZE_BYTES;
710701
tensorAsByteArray = new byte[(int) numel()];
711702
Tensor_int8 thiz = (Tensor_int8) this;
712703
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
713704
} else if (dtype() == DType.INT16) {
714705
throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
715706
} else if (dtype() == DType.INT32) {
716-
dtypeSize = 4;
707+
dtypeSize = INT_SIZE_BYTES;
717708
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
718709
Tensor_int32 thiz = (Tensor_int32) this;
719710
ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
720711
} else if (dtype() == DType.INT64) {
721-
dtypeSize = 8;
712+
dtypeSize = LONG_SIZE_BYTES;
722713
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
723714
Tensor_int64 thiz = (Tensor_int64) this;
724715
ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
716+
} else if (dtype() == DType.FLOAT) {
717+
dtypeSize = FLOAT_SIZE_BYTES;
718+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
719+
Tensor_float32 thiz = (Tensor_float32) this;
720+
ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
721+
} else if (dtype() == DType.DOUBLE) {
722+
dtypeSize = DOUBLE_SIZE_BYTES;
723+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
724+
Tensor_float64 thiz = (Tensor_float64) this;
725+
ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
725726
} else {
726727
throw new IllegalArgumentException("Unknown Tensor dtype");
727728
}
@@ -752,30 +753,30 @@ public static Tensor fromByteArray(byte[] bytes) {
752753
if (!buffer.hasRemaining()) {
753754
throw new IllegalArgumentException("invalid buffer");
754755
}
755-
byte scalarType = buffer.get();
756-
byte numberOfDimensions = buffer.get();
757-
long[] shape = new long[(int) numberOfDimensions];
756+
byte dtype = buffer.get();
757+
byte shapeLength = buffer.get();
758+
long[] shape = new long[(int) shapeLength];
758759
long numel = 1;
759-
for (int i = 0; i < numberOfDimensions; i++) {
760+
for (int i = 0; i < shapeLength; i++) {
760761
int dim = buffer.getInt();
761762
if (dim < 0) {
762763
throw new IllegalArgumentException("invalid shape");
763764
}
764765
shape[i] = dim;
765766
numel *= dim;
766767
}
767-
if (scalarType == DType.FLOAT.jniCode) {
768-
return new Tensor_float32(buffer.asFloatBuffer(), shape);
769-
} else if (scalarType == DType.DOUBLE.jniCode) {
770-
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
771-
} else if (scalarType == DType.UINT8.jniCode) {
768+
if (dtype == DType.UINT8.jniCode) {
772769
return new Tensor_uint8(buffer, shape);
773-
} else if (scalarType == DType.INT8.jniCode) {
770+
} else if (dtype == DType.INT8.jniCode) {
774771
return new Tensor_int8(buffer, shape);
775-
} else if (scalarType == DType.INT16.jniCode) {
772+
} else if (dtype == DType.INT32.jniCode) {
776773
return new Tensor_int32(buffer.asIntBuffer(), shape);
777-
} else if (scalarType == DType.INT64.jniCode) {
774+
} else if (dtype == DType.INT64.jniCode) {
778775
return new Tensor_int64(buffer.asLongBuffer(), shape);
776+
} else if (dtype == DType.FLOAT.jniCode) {
777+
return new Tensor_float32(buffer.asFloatBuffer(), shape);
778+
} else if (dtype == DType.DOUBLE.jniCode) {
779+
return new Tensor_float64(buffer.asDoubleBuffer(), shape);
779780
} else {
780781
throw new IllegalArgumentException("Unknown Tensor dtype");
781782
}

extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@
2222
/** Unit tests for {@link EValue}. */
2323
@RunWith(JUnit4.class)
2424
public class EValueTest {
25-
private static final int TYPE_CODE_NONE = 0;
26-
private static final int TYPE_CODE_TENSOR = 1;
27-
private static final int TYPE_CODE_STRING = 2;
28-
private static final int TYPE_CODE_DOUBLE = 3;
29-
private static final int TYPE_CODE_INT = 4;
30-
private static final int TYPE_CODE_BOOL = 5;
31-
3225
@Test
3326
public void testNone() {
3427
EValue evalue = EValue.optionalNone();
@@ -227,7 +220,6 @@ public void testAllIllegalCast() {
227220
public void testNoneSerde() {
228221
EValue evalue = EValue.optionalNone();
229222
byte[] bytes = evalue.toByteArray();
230-
assertEquals(TYPE_CODE_NONE, bytes[0]);
231223

232224
EValue deser = EValue.fromByteArray(bytes);
233225
assertEquals(deser.isNone(), true);
@@ -237,7 +229,6 @@ public void testNoneSerde() {
237229
public void testBoolSerde() {
238230
EValue evalue = EValue.from(true);
239231
byte[] bytes = evalue.toByteArray();
240-
assertEquals(TYPE_CODE_BOOL, bytes[0]);
241232
assertEquals(1, bytes[1]);
242233

243234
EValue deser = EValue.fromByteArray(bytes);
@@ -249,7 +240,6 @@ public void testBoolSerde() {
249240
public void testBoolSerde2() {
250241
EValue evalue = EValue.from(false);
251242
byte[] bytes = evalue.toByteArray();
252-
assertEquals(TYPE_CODE_BOOL, bytes[0]);
253243
assertEquals(0, bytes[1]);
254244

255245
EValue deser = EValue.fromByteArray(bytes);
@@ -261,7 +251,6 @@ public void testBoolSerde2() {
261251
public void testIntSerde() {
262252
EValue evalue = EValue.from(1);
263253
byte[] bytes = evalue.toByteArray();
264-
assertEquals(TYPE_CODE_INT, bytes[0]);
265254
assertEquals(0, bytes[1]);
266255
assertEquals(0, bytes[2]);
267256
assertEquals(0, bytes[3]);
@@ -280,7 +269,6 @@ public void testIntSerde() {
280269
public void testLargeIntSerde() {
281270
EValue evalue = EValue.from(256000);
282271
byte[] bytes = evalue.toByteArray();
283-
assertEquals(TYPE_CODE_INT, bytes[0]);
284272

285273
EValue deser = EValue.fromByteArray(bytes);
286274
assertEquals(deser.isInt(), true);
@@ -291,7 +279,6 @@ public void testLargeIntSerde() {
291279
public void testDoubleSerde() {
292280
EValue evalue = EValue.from(1.345e-2d);
293281
byte[] bytes = evalue.toByteArray();
294-
assertEquals(TYPE_CODE_DOUBLE, bytes[0]);
295282

296283
EValue deser = EValue.fromByteArray(bytes);
297284
assertEquals(deser.isDouble(), true);
@@ -306,7 +293,6 @@ public void testLongTensorSerde() {
306293

307294
EValue evalue = EValue.from(tensor);
308295
byte[] bytes = evalue.toByteArray();
309-
assertEquals(TYPE_CODE_TENSOR, bytes[0]);
310296

311297
EValue deser = EValue.fromByteArray(bytes);
312298
assertEquals(deser.isTensor(), true);
@@ -331,7 +317,6 @@ public void testFloatTensorSerde() {
331317

332318
EValue evalue = EValue.from(tensor);
333319
byte[] bytes = evalue.toByteArray();
334-
assertEquals(TYPE_CODE_TENSOR, bytes[0]);
335320

336321
EValue deser = EValue.fromByteArray(bytes);
337322
assertEquals(deser.isTensor(), true);

0 commit comments

Comments
 (0)