Skip to content

Commit 2767010

Browse files
committed
change type constraint of bitcast
1 parent 76aab82 commit 2767010

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp
13941394

13951395
// Bitcast can convert between memrefs of signless integers and floats.
13961396
def BitcastTypeConstraint : TypeConstraint<Or<[
1397-
SignlessIntegerOrFloatLike.predicate,
1397+
SignlessInteger.predicate, FloatLike.predicate,
13981398
MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
13991399
"signless-integer-or-float-like or memref of signless-integer or float">;
14001400

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
7575
```
7676
}];
7777

78-
let arguments = (ins AnyTensor:$source);
79-
let results = (outs AnyTensor:$dest);
78+
let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
79+
AnySignedInteger, AnyFloat]>:$source);
80+
let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
81+
AnySignedInteger, AnyFloat]>:$dest);
8082
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
8183

8284
let hasCanonicalizer = 1;

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,11 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;
908908

909909
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
910910

911+
// Type constraint for signless-integer-like types: signless integers,
912+
// vectors of signless integers or tensors of signless integers.
913+
def SignlessInteger : TypeOrValueSemanticsContainer<
914+
AnySignlessInteger, "signless-integer">;
915+
911916
// Type constraint for signless-integer-like types: signless integers, indices,
912917
// vectors of signless integers or indices, tensors of signless integers.
913918
def SignlessIntegerLike : TypeOrValueSemanticsContainer<

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,15 +857,15 @@ func.func @select_tensor_encoding(
857857
// -----
858858

859859
func.func @bitcast_index_0(%arg0 : i64) -> index {
860-
// expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
860+
// expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
861861
%0 = arith.bitcast %arg0 : i64 to index
862862
return %0 : index
863863
}
864864

865865
// -----
866866

867867
func.func @bitcast_index_1(%arg0 : index) -> i64 {
868-
// expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
868+
// expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
869869
%0 = arith.bitcast %arg0 : index to i64
870870
return %0 : i64
871871
}

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,15 +811,15 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
811811
// -----
812812

813813
func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
814-
// expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
814+
// expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
815815
%0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
816816
return %0 : tensor<?xindex>
817817
}
818818

819819
// -----
820820

821821
func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
822-
// expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
822+
// expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
823823
%0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
824824
return %0 : tensor<?xi64>
825825
}

0 commit comments

Comments
 (0)