@@ -53,9 +53,10 @@ public abstract class Tensor {
53
53
54
54
@ DoNotStrip final long [] shape ;
55
55
56
+ private static final int BYTE_SIZE_BYTES = 1 ;
56
57
private static final int INT_SIZE_BYTES = 4 ;
57
- private static final int FLOAT_SIZE_BYTES = 4 ;
58
58
private static final int LONG_SIZE_BYTES = 8 ;
59
+ private static final int FLOAT_SIZE_BYTES = 4 ;
59
60
private static final int DOUBLE_SIZE_BYTES = 8 ;
60
61
61
62
/**
@@ -690,38 +691,38 @@ private static Tensor nativeNewTensor(
690
691
public byte [] toByteArray () {
691
692
int dtypeSize = 0 ;
692
693
byte [] tensorAsByteArray = null ;
693
- if (dtype () == DType .FLOAT ) {
694
- dtypeSize = 4 ;
695
- tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
696
- Tensor_float32 thiz = (Tensor_float32 ) this ;
697
- ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
698
- } else if (dtype () == DType .DOUBLE ) {
699
- dtypeSize = 8 ;
700
- tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
701
- Tensor_float64 thiz = (Tensor_float64 ) this ;
702
- ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
703
- } else if (dtype () == DType .UINT8 ) {
704
- dtypeSize = 1 ;
694
+ if (dtype () == DType .UINT8 ) {
695
+ dtypeSize = BYTE_SIZE_BYTES ;
705
696
tensorAsByteArray = new byte [(int ) numel ()];
706
697
Tensor_uint8 thiz = (Tensor_uint8 ) this ;
707
698
ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
708
699
} else if (dtype () == DType .INT8 ) {
709
- dtypeSize = 1 ;
700
+ dtypeSize = BYTE_SIZE_BYTES ;
710
701
tensorAsByteArray = new byte [(int ) numel ()];
711
702
Tensor_int8 thiz = (Tensor_int8 ) this ;
712
703
ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
713
704
} else if (dtype () == DType .INT16 ) {
714
705
throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
715
706
} else if (dtype () == DType .INT32 ) {
716
- dtypeSize = 4 ;
707
+ dtypeSize = INT_SIZE_BYTES ;
717
708
tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
718
709
Tensor_int32 thiz = (Tensor_int32 ) this ;
719
710
ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
720
711
} else if (dtype () == DType .INT64 ) {
721
- dtypeSize = 8 ;
712
+ dtypeSize = LONG_SIZE_BYTES ;
722
713
tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
723
714
Tensor_int64 thiz = (Tensor_int64 ) this ;
724
715
ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
716
+ } else if (dtype () == DType .FLOAT ) {
717
+ dtypeSize = FLOAT_SIZE_BYTES ;
718
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
719
+ Tensor_float32 thiz = (Tensor_float32 ) this ;
720
+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
721
+ } else if (dtype () == DType .DOUBLE ) {
722
+ dtypeSize = DOUBLE_SIZE_BYTES ;
723
+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
724
+ Tensor_float64 thiz = (Tensor_float64 ) this ;
725
+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
725
726
} else {
726
727
throw new IllegalArgumentException ("Unknown Tensor dtype" );
727
728
}
@@ -752,30 +753,30 @@ public static Tensor fromByteArray(byte[] bytes) {
752
753
if (!buffer .hasRemaining ()) {
753
754
throw new IllegalArgumentException ("invalid buffer" );
754
755
}
755
- byte scalarType = buffer .get ();
756
- byte numberOfDimensions = buffer .get ();
757
- long [] shape = new long [(int ) numberOfDimensions ];
756
+ byte dtype = buffer .get ();
757
+ byte shapeLength = buffer .get ();
758
+ long [] shape = new long [(int ) shapeLength ];
758
759
long numel = 1 ;
759
- for (int i = 0 ; i < numberOfDimensions ; i ++) {
760
+ for (int i = 0 ; i < shapeLength ; i ++) {
760
761
int dim = buffer .getInt ();
761
762
if (dim < 0 ) {
762
763
throw new IllegalArgumentException ("invalid shape" );
763
764
}
764
765
shape [i ] = dim ;
765
766
numel *= dim ;
766
767
}
767
- if (scalarType == DType .FLOAT .jniCode ) {
768
- return new Tensor_float32 (buffer .asFloatBuffer (), shape );
769
- } else if (scalarType == DType .DOUBLE .jniCode ) {
770
- return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
771
- } else if (scalarType == DType .UINT8 .jniCode ) {
768
+ if (dtype == DType .UINT8 .jniCode ) {
772
769
return new Tensor_uint8 (buffer , shape );
773
- } else if (scalarType == DType .INT8 .jniCode ) {
770
+ } else if (dtype == DType .INT8 .jniCode ) {
774
771
return new Tensor_int8 (buffer , shape );
775
- } else if (scalarType == DType .INT16 .jniCode ) {
772
+ } else if (dtype == DType .INT32 .jniCode ) {
776
773
return new Tensor_int32 (buffer .asIntBuffer (), shape );
777
- } else if (scalarType == DType .INT64 .jniCode ) {
774
+ } else if (dtype == DType .INT64 .jniCode ) {
778
775
return new Tensor_int64 (buffer .asLongBuffer (), shape );
776
+ } else if (dtype == DType .FLOAT .jniCode ) {
777
+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
778
+ } else if (dtype == DType .DOUBLE .jniCode ) {
779
+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
779
780
} else {
780
781
throw new IllegalArgumentException ("Unknown Tensor dtype" );
781
782
}
0 commit comments