Skip to content

[Android] Remove runtime internal evalue list types #7012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 3 additions & 119 deletions extension/android/src/main/java/org/pytorch/executorch/EValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand Down Expand Up @@ -44,26 +43,8 @@ public class EValue {
private static final int TYPE_CODE_INT = 4;
private static final int TYPE_CODE_BOOL = 5;

private static final int TYPE_CODE_LIST_BOOL = 6;
private static final int TYPE_CODE_LIST_DOUBLE = 7;
private static final int TYPE_CODE_LIST_INT = 8;
private static final int TYPE_CODE_LIST_TENSOR = 9;
private static final int TYPE_CODE_LIST_SCALAR = 10;
private static final int TYPE_CODE_LIST_OPTIONAL_TENSOR = 11;

private String[] TYPE_NAMES = {
"None",
"Tensor",
"String",
"Double",
"Int",
"Bool",
"ListBool",
"ListDouble",
"ListInt",
"ListTensor",
"ListScalar",
"ListOptionalTensor",
"None", "Tensor", "String", "Double", "Int", "Bool",
};

@DoNotStrip private final int mTypeCode;
Expand Down Expand Up @@ -104,31 +85,6 @@ public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}

@DoNotStrip
public boolean isBoolList() {
return TYPE_CODE_LIST_BOOL == this.mTypeCode;
}

@DoNotStrip
public boolean isIntList() {
return TYPE_CODE_LIST_INT == this.mTypeCode;
}

@DoNotStrip
public boolean isDoubleList() {
return TYPE_CODE_LIST_DOUBLE == this.mTypeCode;
}

@DoNotStrip
public boolean isTensorList() {
return TYPE_CODE_LIST_TENSOR == this.mTypeCode;
}

@DoNotStrip
public boolean isOptionalTensorList() {
return TYPE_CODE_LIST_OPTIONAL_TENSOR == this.mTypeCode;
}

/** Creates a new {@code EValue} of type {@code Optional} that contains no value. */
@DoNotStrip
public static EValue optionalNone() {
Expand Down Expand Up @@ -175,46 +131,6 @@ public static EValue from(String value) {
return iv;
}

/** Creates a new {@code EValue} of type {@code List[bool]}. */
@DoNotStrip
public static EValue listFrom(boolean... list) {
final EValue iv = new EValue(TYPE_CODE_LIST_BOOL);
iv.mData = list;
return iv;
}

/** Creates a new {@code EValue} of type {@code List[int]}. */
@DoNotStrip
public static EValue listFrom(long... list) {
final EValue iv = new EValue(TYPE_CODE_LIST_INT);
iv.mData = list;
return iv;
}

/** Creates a new {@code EValue} of type {@code List[double]}. */
@DoNotStrip
public static EValue listFrom(double... list) {
final EValue iv = new EValue(TYPE_CODE_LIST_DOUBLE);
iv.mData = list;
return iv;
}

/** Creates a new {@code EValue} of type {@code List[Tensor]}. */
@DoNotStrip
public static EValue listFrom(Tensor... list) {
final EValue iv = new EValue(TYPE_CODE_LIST_TENSOR);
iv.mData = list;
return iv;
}

/** Creates a new {@code EValue} of type {@code List[Optional[Tensor]]}. */
@DoNotStrip
public static EValue listFrom(Optional<Tensor>... list) {
final EValue iv = new EValue(TYPE_CODE_LIST_OPTIONAL_TENSOR);
iv.mData = list;
return iv;
}

@DoNotStrip
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
Expand Down Expand Up @@ -245,36 +161,6 @@ public String toStr() {
return (String) mData;
}

@DoNotStrip
public boolean[] toBoolList() {
preconditionType(TYPE_CODE_LIST_BOOL, mTypeCode);
return (boolean[]) mData;
}

@DoNotStrip
public long[] toIntList() {
preconditionType(TYPE_CODE_LIST_INT, mTypeCode);
return (long[]) mData;
}

@DoNotStrip
public double[] toDoubleList() {
preconditionType(TYPE_CODE_LIST_DOUBLE, mTypeCode);
return (double[]) mData;
}

@DoNotStrip
public Tensor[] toTensorList() {
preconditionType(TYPE_CODE_LIST_TENSOR, mTypeCode);
return (Tensor[]) mData;
}

@DoNotStrip
public Optional<Tensor>[] toOptionalTensorList() {
preconditionType(TYPE_CODE_LIST_OPTIONAL_TENSOR, mTypeCode);
return (Optional<Tensor>[]) mData;
}

private void preconditionType(int typeCodeExpected, int typeCode) {
if (typeCode != typeCodeExpected) {
throw new IllegalStateException(
Expand All @@ -294,8 +180,7 @@ private String getTypeName(int typeCode) {
* Serializes an {@code EValue} into a byte array.
*
* @return The serialized byte array.
* @apiNote This method is experimental and subject to change without notice. This does NOT
* supoprt list type.
* @apiNote This method is experimental and subject to change without notice.
*/
public byte[] toByteArray() {
if (isNone()) {
Expand Down Expand Up @@ -331,8 +216,7 @@ public byte[] toByteArray() {
*
* @param bytes The byte array to deserialize from.
* @return The deserialized {@code EValue}.
* @apiNote This method is experimental and subject to change without notice. This does NOT list
* type.
* @apiNote This method is experimental and subject to change without notice.
*/
public static EValue fromByteArray(byte[] bytes) {
ByteBuffer buffer = ByteBuffer.wrap(bytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import static org.junit.Assert.fail;

import java.util.Arrays;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -66,70 +65,6 @@ public void testStringValue() {
assertEquals(evalue.toStr(), "a");
}

@Test
public void testBoolListValue() {
boolean[] value = {true, false, true};
EValue evalue = EValue.listFrom(value);
assertTrue(evalue.isBoolList());
assertTrue(Arrays.equals(value, evalue.toBoolList()));
}

@Test
public void testIntListValue() {
long[] value = {Long.MIN_VALUE, 0, Long.MAX_VALUE};
EValue evalue = EValue.listFrom(value);
assertTrue(evalue.isIntList());
assertTrue(Arrays.equals(value, evalue.toIntList()));
}

@Test
public void testDoubleListValue() {
double[] value = {Double.MIN_VALUE, 0.1d, 0.01d, 0.001d, Double.MAX_VALUE};
EValue evalue = EValue.listFrom(value);
assertTrue(evalue.isDoubleList());
assertTrue(Arrays.equals(value, evalue.toDoubleList()));
}

@Test
public void testTensorListValue() {
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
long[][] shape = {{1, 3}, {2, 3}};
Tensor[] tensors = {Tensor.fromBlob(data[0], shape[0]), Tensor.fromBlob(data[1], shape[1])};

EValue evalue = EValue.listFrom(tensors);
assertTrue(evalue.isTensorList());

assertTrue(Arrays.equals(evalue.toTensorList()[0].shape, shape[0]));
assertTrue(Arrays.equals(evalue.toTensorList()[0].getDataAsLongArray(), data[0]));

assertTrue(Arrays.equals(evalue.toTensorList()[1].shape, shape[1]));
assertTrue(Arrays.equals(evalue.toTensorList()[1].getDataAsLongArray(), data[1]));
}

@Test
@SuppressWarnings("unchecked")
public void testOptionalTensorListValue() {
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
long[][] shape = {{1, 3}, {2, 3}};

EValue evalue =
EValue.listFrom(
Optional.<Tensor>empty(),
Optional.of(Tensor.fromBlob(data[0], shape[0])),
Optional.of(Tensor.fromBlob(data[1], shape[1])));
assertTrue(evalue.isOptionalTensorList());

assertTrue(!evalue.toOptionalTensorList()[0].isPresent());

assertTrue(evalue.toOptionalTensorList()[1].isPresent());
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0]));
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().getDataAsLongArray(), data[0]));

assertTrue(evalue.toOptionalTensorList()[2].isPresent());
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().shape, shape[1]));
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().getDataAsLongArray(), data[1]));
}

@Test
public void testAllIllegalCast() {
EValue evalue = EValue.optionalNone();
Expand Down Expand Up @@ -174,46 +109,6 @@ public void testAllIllegalCast() {
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}

// try bool list
assertFalse(evalue.isBoolList());
try {
evalue.toBoolList();
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}

// try int list
assertFalse(evalue.isIntList());
try {
evalue.toIntList();
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}

// try double list
assertFalse(evalue.isDoubleList());
try {
evalue.toBool();
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}

// try Tensor list
assertFalse(evalue.isTensorList());
try {
evalue.toTensorList();
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}

// try optional Tensor list
assertFalse(evalue.isOptionalTensorList());
try {
evalue.toOptionalTensorList();
fail("Should have thrown an exception");
} catch (IllegalStateException e) {
}
}

@Test
Expand Down
Loading