|
10 | 10 |
|
11 | 11 | import com.facebook.jni.HybridData;
|
12 | 12 | import com.facebook.jni.annotations.DoNotStrip;
|
| 13 | +import java.io.ByteArrayOutputStream; |
| 14 | +import java.io.ObjectOutputStream; |
13 | 15 | import java.nio.Buffer;
|
14 | 16 | import java.nio.ByteBuffer;
|
15 | 17 | import java.nio.ByteOrder;
|
@@ -679,4 +681,99 @@ private static Tensor nativeNewTensor(
|
679 | 681 | tensor.mHybridData = hybridData;
|
680 | 682 | return tensor;
|
681 | 683 | }
|
| 684 | + |
| 685 | + private static final byte[] intToByteArray(int value) { |
| 686 | + return new byte[] { |
| 687 | + (byte)(value >>> 24), |
| 688 | + (byte)(value >>> 16), |
| 689 | + (byte)(value >>> 8), |
| 690 | + (byte)value}; |
| 691 | + } |
| 692 | + |
| 693 | + /** |
| 694 | + * Serializes a {@code Tensor} into a {@link ByteBuffer}. |
| 695 | + * @return The serialized {@code ByteBuffer}. |
| 696 | + * |
| 697 | + * @apiNote This method is experimental and subject to change without notice. |
| 698 | + * This does NOT supoprt list type. |
| 699 | + */ |
| 700 | + public ByteBuffer toByteBuffer() { |
| 701 | + int dtypeSize = 0; |
| 702 | + if (dtype() == DType.FLOAT) { |
| 703 | + dtypeSize = 4; |
| 704 | + } else if (dtype() == DType.DOUBLE) { |
| 705 | + dtypeSize = 8; |
| 706 | + } else if (dtype() == DType.UINT8) { |
| 707 | + dtypeSize = 1; |
| 708 | + } else if (dtype() == DType.INT8) { |
| 709 | + dtypeSize = 1; |
| 710 | + } else if (dtype() == DType.INT16) { |
| 711 | + dtypeSize = 2; |
| 712 | + } else if (dtype() == DType.INT32) { |
| 713 | + dtypeSize = 4; |
| 714 | + } else if (dtype() == DType.INT64) { |
| 715 | + dtypeSize = 8; |
| 716 | + } else { |
| 717 | + throw new IllegalArgumentException("Unknown Tensor dtype"); |
| 718 | + } |
| 719 | + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(1 + 1 + 4 * shape.length + dtypeSize * (int) numel()); |
| 720 | + byteBuffer.put((byte) dtype().jniCode); |
| 721 | + byteBuffer.put((byte) shape.length); |
| 722 | + for (long s : shape) { |
| 723 | + byteBuffer.put(intToByteArray((int) s)); |
| 724 | + } |
| 725 | + ByteArrayOutputStream bos = new ByteArrayOutputStream(); |
| 726 | + try (ObjectOutputStream out = new ObjectOutputStream(bos)) { |
| 727 | + out.writeObject(getRawDataBuffer()); |
| 728 | + out.flush(); |
| 729 | + byteBuffer.put(bos.toByteArray()); |
| 730 | + } catch (Exception ex) { |
| 731 | + throw new RuntimeException(ex); |
| 732 | + } |
| 733 | + return byteBuffer; |
| 734 | + } |
| 735 | + |
| 736 | + /** |
| 737 | + * Deserializes a {@code Tensor} from a {@link ByteBuffer}. |
| 738 | + * @param buffer The {@link ByteBuffer} to deserialize from. |
| 739 | + * @return The deserialized {@code Tensor}. |
| 740 | + * |
| 741 | + * @apiNote This method is experimental and subject to change without notice. |
| 742 | + * This does NOT supoprt list type. |
| 743 | + */ |
| 744 | + public static Tensor fromByteBuffer(ByteBuffer buffer) { |
| 745 | + if (buffer == null) { |
| 746 | + throw new IllegalArgumentException("buffer cannot be null"); |
| 747 | + } |
| 748 | + if (!buffer.hasRemaining()) { |
| 749 | + throw new IllegalArgumentException("invalid buffer"); |
| 750 | + } |
| 751 | + byte scalarType = buffer.get(); |
| 752 | + byte numberOfDimensions = buffer.get(); |
| 753 | + long[] shape = new long[(int) numberOfDimensions]; |
| 754 | + long numel = 1; |
| 755 | + for (int i = 0; i < numberOfDimensions; i++) { |
| 756 | + int dim = buffer.getInt(); |
| 757 | + if (dim < 0) { |
| 758 | + throw new IllegalArgumentException("invalid shape"); |
| 759 | + } |
| 760 | + shape[i] = dim; |
| 761 | + numel *= dim; |
| 762 | + } |
| 763 | + if (scalarType == DType.FLOAT.jniCode) { |
| 764 | + return new Tensor_float32(buffer.asFloatBuffer(), shape); |
| 765 | + } else if (scalarType == DType.DOUBLE.jniCode) { |
| 766 | + return new Tensor_float64(buffer.asDoubleBuffer(), shape); |
| 767 | + } else if (scalarType == DType.UINT8.jniCode) { |
| 768 | + return new Tensor_uint8(buffer, shape); |
| 769 | + } else if (scalarType == DType.INT8.jniCode) { |
| 770 | + return new Tensor_int8(buffer, shape); |
| 771 | + } else if (scalarType == DType.INT16.jniCode) { |
| 772 | + return new Tensor_int32(buffer.asIntBuffer(), shape); |
| 773 | + } else if (scalarType == DType.INT64.jniCode) { |
| 774 | + return new Tensor_int64(buffer.asLongBuffer(), shape); |
| 775 | + } else { |
| 776 | + throw new IllegalArgumentException("Unknown Tensor dtype"); |
| 777 | + } |
| 778 | + } |
682 | 779 | }
|
0 commit comments