@@ -53,9 +53,10 @@ public abstract class Tensor {
53
53
54
54
@ DoNotStrip final long [] shape ;
55
55
56
+ private static final int BYTE_SIZE_BYTES = 1 ;
56
57
private static final int INT_SIZE_BYTES = 4 ;
57
- private static final int FLOAT_SIZE_BYTES = 4 ;
58
58
private static final int LONG_SIZE_BYTES = 8 ;
59
+ private static final int FLOAT_SIZE_BYTES = 4 ;
59
60
private static final int DOUBLE_SIZE_BYTES = 8 ;
60
61
61
62
/**
@@ -679,4 +680,105 @@ private static Tensor nativeNewTensor(
679
680
tensor .mHybridData = hybridData ;
680
681
return tensor ;
681
682
}
683
+
684
+ /**
685
+ * Serializes a {@code Tensor} into a byte array.
686
+ *
687
+ * @return The serialized byte array.
688
+ * @apiNote This method is experimental and subject to change without notice. This does NOT
689
+ * supoprt list type.
690
+ */
691
+ public byte [] toByteArray () {
692
+ int dtypeSize = 0 ;
693
+ byte [] tensorAsByteArray = null ;
694
+ if (dtype () == DType .UINT8 ) {
695
+ dtypeSize = BYTE_SIZE_BYTES ;
696
+ tensorAsByteArray = new byte [(int ) numel ()];
697
+ Tensor_uint8 thiz = (Tensor_uint8 ) this ;
698
+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
699
+ } else if (dtype () == DType .INT8 ) {
700
+ dtypeSize = BYTE_SIZE_BYTES ;
701
+ tensorAsByteArray = new byte [(int ) numel ()];
702
+ Tensor_int8 thiz = (Tensor_int8 ) this ;
703
+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
704
+ } else if (dtype () == DType .INT16 ) {
705
+ throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
706
+ } else if (dtype () == DType .INT32 ) {
707
+ dtypeSize = INT_SIZE_BYTES ;
708
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
709
+ Tensor_int32 thiz = (Tensor_int32 ) this ;
710
+ ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
711
+ } else if (dtype () == DType .INT64 ) {
712
+ dtypeSize = LONG_SIZE_BYTES ;
713
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
714
+ Tensor_int64 thiz = (Tensor_int64 ) this ;
715
+ ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
716
+ } else if (dtype () == DType .FLOAT ) {
717
+ dtypeSize = FLOAT_SIZE_BYTES ;
718
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
719
+ Tensor_float32 thiz = (Tensor_float32 ) this ;
720
+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
721
+ } else if (dtype () == DType .DOUBLE ) {
722
+ dtypeSize = DOUBLE_SIZE_BYTES ;
723
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
724
+ Tensor_float64 thiz = (Tensor_float64 ) this ;
725
+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
726
+ } else {
727
+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
728
+ }
729
+ ByteBuffer byteBuffer =
730
+ ByteBuffer .allocate (1 + 1 + 4 * shape .length + dtypeSize * (int ) numel ());
731
+ byteBuffer .put ((byte ) dtype ().jniCode );
732
+ byteBuffer .put ((byte ) shape .length );
733
+ for (long s : shape ) {
734
+ byteBuffer .putInt ((int ) s );
735
+ }
736
+ byteBuffer .put (tensorAsByteArray );
737
+ return byteBuffer .array ();
738
+ }
739
+
740
+ /**
741
+ * Deserializes a {@code Tensor} from a byte[].
742
+ *
743
+ * @param buffer The byte array to deserialize from.
744
+ * @return The deserialized {@code Tensor}.
745
+ * @apiNote This method is experimental and subject to change without notice. This does NOT
746
+ * supoprt list type.
747
+ */
748
+ public static Tensor fromByteArray (byte [] bytes ) {
749
+ if (bytes == null ) {
750
+ throw new IllegalArgumentException ("bytes cannot be null" );
751
+ }
752
+ ByteBuffer buffer = ByteBuffer .wrap (bytes );
753
+ if (!buffer .hasRemaining ()) {
754
+ throw new IllegalArgumentException ("invalid buffer" );
755
+ }
756
+ byte dtype = buffer .get ();
757
+ byte shapeLength = buffer .get ();
758
+ long [] shape = new long [(int ) shapeLength ];
759
+ long numel = 1 ;
760
+ for (int i = 0 ; i < shapeLength ; i ++) {
761
+ int dim = buffer .getInt ();
762
+ if (dim < 0 ) {
763
+ throw new IllegalArgumentException ("invalid shape" );
764
+ }
765
+ shape [i ] = dim ;
766
+ numel *= dim ;
767
+ }
768
+ if (dtype == DType .UINT8 .jniCode ) {
769
+ return new Tensor_uint8 (buffer , shape );
770
+ } else if (dtype == DType .INT8 .jniCode ) {
771
+ return new Tensor_int8 (buffer , shape );
772
+ } else if (dtype == DType .INT32 .jniCode ) {
773
+ return new Tensor_int32 (buffer .asIntBuffer (), shape );
774
+ } else if (dtype == DType .INT64 .jniCode ) {
775
+ return new Tensor_int64 (buffer .asLongBuffer (), shape );
776
+ } else if (dtype == DType .FLOAT .jniCode ) {
777
+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
778
+ } else if (dtype == DType .DOUBLE .jniCode ) {
779
+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
780
+ } else {
781
+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
782
+ }
783
+ }
682
784
}
0 commit comments