Skip to content

Commit 22863ae

Browse files
committed
Fix
1 parent 1e4f327 commit 22863ae

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,13 @@ public enum DType {
7373
DType(int jniCode) {
7474
this.jniCode = jniCode;
7575
}
76+
77+
public static DType fromJniCode(int jniCode) {
78+
for (DType dtype : values()) {
79+
if (dtype.jniCode == jniCode) {
80+
return dtype;
81+
}
82+
}
83+
throw new IllegalArgumentException("No DType found for jniCode " + jniCode);
84+
}
7685
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.pytorch.executorch;
1010

11+
import android.util.Log;
1112
import com.facebook.jni.HybridData;
1213
import com.facebook.jni.annotations.DoNotStrip;
1314
import java.nio.Buffer;
@@ -630,14 +631,17 @@ public String toString() {
630631
}
631632
}
632633

633-
static class Tensor_unknown extends Tensor {
634+
static class Tensor_unsupported extends Tensor {
634635
private final ByteBuffer data;
635636
private final DType myDtype;
636637

637-
private Tensor_unknown(ByteBuffer data, long[] shape, DType dtype) {
638+
private Tensor_unsupported(ByteBuffer data, long[] shape, DType dtype) {
638639
super(shape);
639640
this.data = data;
640641
this.myDtype = dtype;
642+
Log.e(
643+
"ExecuTorch",
644+
toString() + " in Java. Please consider re-export the model with proper return type");
641645
}
642646

643647
@Override
@@ -647,7 +651,8 @@ public DType dtype() {
647651

648652
@Override
649653
public String toString() {
650-
return String.format("Tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype);
654+
return String.format(
655+
"Unsupported tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype);
651656
}
652657
}
653658

@@ -696,7 +701,7 @@ private static Tensor nativeNewTensor(
696701
} else if (DType.INT8.jniCode == dtype) {
697702
tensor = new Tensor_int8(data, shape);
698703
} else {
699-
tensor = new Tensor_unknown(data, shape, dtype);
704+
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
700705
}
701706
tensor.mHybridData = hybridData;
702707
return tensor;

0 commit comments

Comments
 (0)