Skip to content

Commit 76aab82

Browse files
committed
disable index type for bitcast
1 parent 4d7c7dd commit 76aab82

File tree

6 files changed

+41
-25
lines changed

6 files changed

+41
-25
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
13921392
// BitcastOp
13931393
//===----------------------------------------------------------------------===//
13941394

1395-
// Bitcast can convert between memrefs of signless integers, indices, and
1396-
// floats too.
1395+
// Bitcast can convert between memrefs of signless integers and floats.
13971396
def BitcastTypeConstraint : TypeConstraint<Or<[
13981397
SignlessIntegerOrFloatLike.predicate,
1399-
MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
1398+
MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
14001399
"signless-integer-or-float-like or memref of signless-integer or float">;
14011400

14021401
def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,25 +1716,12 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
17161716
if (!areValidCastInputsAndOutputs(inputs, outputs))
17171717
return false;
17181718

1719-
auto srcType =
1720-
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1721-
auto dstType =
1722-
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1719+
auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1720+
auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
17231721
if (!srcType || !dstType)
17241722
return false;
17251723

1726-
unsigned srcWidth, dstWidth;
1727-
if (auto indexTy = dyn_cast<IndexType>(srcType))
1728-
srcWidth = IndexType::kInternalStorageBitWidth;
1729-
else
1730-
srcWidth = srcType.getIntOrFloatBitWidth();
1731-
1732-
if (auto indexTy = dyn_cast<IndexType>(dstType))
1733-
dstWidth = IndexType::kInternalStorageBitWidth;
1734-
else
1735-
dstWidth = dstType.getIntOrFloatBitWidth();
1736-
1737-
return srcWidth == dstWidth;
1724+
return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
17381725
}
17391726

17401727
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
219219
if (!aT || !bT)
220220
return false;
221221

222+
if (isa<IndexType>(aT.getElementType()) ||
223+
isa<IndexType>(bT.getElementType()))
224+
return false;
225+
222226
if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223227
return false;
224228

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
853853
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
854854
return %0 : tensor<8xi32, "foo">
855855
}
856+
857+
// -----
858+
859+
func.func @bitcast_index_0(%arg0 : i64) -> index {
860+
// expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
861+
%0 = arith.bitcast %arg0 : i64 to index
862+
return %0 : index
863+
}
864+
865+
// -----
866+
867+
func.func @bitcast_index_1(%arg0 : index) -> i64 {
868+
// expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
869+
%0 = arith.bitcast %arg0 : index to i64
870+
return %0 : i64
871+
}

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -954,12 +954,6 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
954954
return %0 : vector<[8]xi32>
955955
}
956956

957-
// CHECK-LABEL: test_bitcast_index
958-
func.func @test_bitcast_index(%arg0 : i64) -> index {
959-
%0 = arith.bitcast %arg0 : i64 to index
960-
return %0 : index
961-
}
962-
963957
// CHECK-LABEL: test_cmpi
964958
func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
965959
%0 = arith.cmpi ne, %arg0, %arg1 : i64

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
807807
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
808808
return %0 : tensor<?x?xf32>
809809
}
810+
811+
// -----
812+
813+
func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
814+
// expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
815+
%0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
816+
return %0 : tensor<?xindex>
817+
}
818+
819+
// -----
820+
821+
func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
822+
// expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
823+
%0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
824+
return %0 : tensor<?xi64>
825+
}

0 commit comments

Comments
 (0)