@@ -679,4 +679,105 @@ private static Tensor nativeNewTensor(
679
679
tensor .mHybridData = hybridData ;
680
680
return tensor ;
681
681
}
682
+
683
+ /**
684
+ * Serializes a {@code Tensor} into a byte array.
685
+ *
686
+ * @return The serialized byte array.
687
+ * @apiNote This method is experimental and subject to change without notice. This does NOT
688
+ * supoprt list type.
689
+ */
690
+ public byte [] toByteArray () {
691
+ int dtypeSize = 0 ;
692
+ byte [] tensorAsByteArray = null ;
693
+ if (dtype () == DType .FLOAT ) {
694
+ dtypeSize = 4 ;
695
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
696
+ Tensor_float32 thiz = (Tensor_float32 ) this ;
697
+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
698
+ } else if (dtype () == DType .DOUBLE ) {
699
+ dtypeSize = 8 ;
700
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
701
+ Tensor_float64 thiz = (Tensor_float64 ) this ;
702
+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
703
+ } else if (dtype () == DType .UINT8 ) {
704
+ dtypeSize = 1 ;
705
+ tensorAsByteArray = new byte [(int ) numel ()];
706
+ Tensor_uint8 thiz = (Tensor_uint8 ) this ;
707
+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
708
+ } else if (dtype () == DType .INT8 ) {
709
+ dtypeSize = 1 ;
710
+ tensorAsByteArray = new byte [(int ) numel ()];
711
+ Tensor_int8 thiz = (Tensor_int8 ) this ;
712
+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
713
+ } else if (dtype () == DType .INT16 ) {
714
+ throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
715
+ } else if (dtype () == DType .INT32 ) {
716
+ dtypeSize = 4 ;
717
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
718
+ Tensor_int32 thiz = (Tensor_int32 ) this ;
719
+ ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
720
+ } else if (dtype () == DType .INT64 ) {
721
+ dtypeSize = 8 ;
722
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
723
+ Tensor_int64 thiz = (Tensor_int64 ) this ;
724
+ ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
725
+ } else {
726
+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
727
+ }
728
+ ByteBuffer byteBuffer =
729
+ ByteBuffer .allocate (1 + 1 + 4 * shape .length + dtypeSize * (int ) numel ());
730
+ byteBuffer .put ((byte ) dtype ().jniCode );
731
+ byteBuffer .put ((byte ) shape .length );
732
+ for (long s : shape ) {
733
+ byteBuffer .putInt ((int ) s );
734
+ }
735
+ byteBuffer .put (tensorAsByteArray );
736
+ return byteBuffer .array ();
737
+ }
738
+
739
+ /**
740
+ * Deserializes a {@code Tensor} from a byte[].
741
+ *
742
+ * @param buffer The byte array to deserialize from.
743
+ * @return The deserialized {@code Tensor}.
744
+ * @apiNote This method is experimental and subject to change without notice. This does NOT
745
+ * supoprt list type.
746
+ */
747
+ public static Tensor fromByteArray (byte [] bytes ) {
748
+ if (bytes == null ) {
749
+ throw new IllegalArgumentException ("bytes cannot be null" );
750
+ }
751
+ ByteBuffer buffer = ByteBuffer .wrap (bytes );
752
+ if (!buffer .hasRemaining ()) {
753
+ throw new IllegalArgumentException ("invalid buffer" );
754
+ }
755
+ byte scalarType = buffer .get ();
756
+ byte numberOfDimensions = buffer .get ();
757
+ long [] shape = new long [(int ) numberOfDimensions ];
758
+ long numel = 1 ;
759
+ for (int i = 0 ; i < numberOfDimensions ; i ++) {
760
+ int dim = buffer .getInt ();
761
+ if (dim < 0 ) {
762
+ throw new IllegalArgumentException ("invalid shape" );
763
+ }
764
+ shape [i ] = dim ;
765
+ numel *= dim ;
766
+ }
767
+ if (scalarType == DType .FLOAT .jniCode ) {
768
+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
769
+ } else if (scalarType == DType .DOUBLE .jniCode ) {
770
+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
771
+ } else if (scalarType == DType .UINT8 .jniCode ) {
772
+ return new Tensor_uint8 (buffer , shape );
773
+ } else if (scalarType == DType .INT8 .jniCode ) {
774
+ return new Tensor_int8 (buffer , shape );
775
+ } else if (scalarType == DType .INT16 .jniCode ) {
776
+ return new Tensor_int32 (buffer .asIntBuffer (), shape );
777
+ } else if (scalarType == DType .INT64 .jniCode ) {
778
+ return new Tensor_int64 (buffer .asLongBuffer (), shape );
779
+ } else {
780
+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
781
+ }
782
+ }
682
783
}
0 commit comments