Skip to content

Commit 4a3c865

Browse files
committed
[mlir] Fix arith verifier for tensor with encoding
The verifier for some arith ops were not considering that ranked tensor types can have encodings. Differential Revision: https://reviews.llvm.org/D156557
1 parent d020fa2 commit 4a3c865

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,10 @@ namespace {
130130
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
131131
static Type getI1SameShape(Type type) {
132132
auto i1Type = IntegerType::get(type.getContext(), 1);
133-
if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
134-
return RankedTensorType::get(tensorType.getShape(), i1Type);
133+
if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
134+
return shapedType.cloneWith(std::nullopt, i1Type);
135135
if (llvm::isa<UnrankedTensorType>(type))
136136
return UnrankedTensorType::get(i1Type);
137-
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
138-
return VectorType::get(vectorType.getShape(), i1Type,
139-
vectorType.getScalableDims());
140137
return i1Type;
141138
}
142139

@@ -1150,9 +1147,21 @@ static Type getTypeIfLikeOrMemRef(Type type) {
11501147
type_list<ElementTypes...>());
11511148
}
11521149

1150+
/// Return false if both types are ranked tensor with mismatching encoding.
1151+
static bool hasSameEncoding(Type typeA, Type typeB) {
1152+
auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1153+
auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1154+
if (!rankedTensorA || !rankedTensorB)
1155+
return true;
1156+
return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1157+
}
1158+
11531159
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1154-
return inputs.size() == 1 && outputs.size() == 1 &&
1155-
succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1160+
if (inputs.size() != 1 || outputs.size() != 1)
1161+
return false;
1162+
if (!hasSameEncoding(inputs.front(), outputs.front()))
1163+
return false;
1164+
return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
11561165
}
11571166

11581167
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ func.func @func_with_ops() {
206206

207207
// -----
208208

209+
func.func @func_with_ops() {
210+
^bb0:
211+
%c = arith.constant dense<0> : tensor<42 x i32, "foo">
212+
// expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
213+
%r = "arith.cmpi"(%c, %c) {predicate = 0} : (tensor<42 x i32, "foo">, tensor<42 x i32, "foo">) -> tensor<42 x i1, "bar">
214+
}
215+
216+
// -----
217+
209218
func.func @invalid_cmp_shape(%idx : () -> ()) {
210219
// expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
211220
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
@@ -420,6 +429,14 @@ func.func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
420429

421430
// -----
422431

432+
func.func @fpext_vec_f32_to_i32(%arg0 : tensor<2xf32, "foo">) {
433+
// expected-error@+1 {{op operand type 'tensor<2xf32, "foo">' and result type 'tensor<2xf64, "bar">' are cast incompatible}}
434+
%0 = arith.extf %arg0 : tensor<2xf32, "foo"> to tensor<2xf64, "bar">
435+
return
436+
}
437+
438+
// -----
439+
423440
func.func @fptrunc_f16_to_f32(%arg0 : f16) {
424441
// expected-error@+1 {{are cast incompatible}}
425442
%0 = arith.truncf %arg0 : f16 to f32
@@ -769,3 +786,12 @@ func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %a
769786
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
770787
return %0 : tensor<2x?xi64>
771788
}
789+
790+
// -----
791+
792+
func.func @select_tensor_encoding(
793+
%arg0 : tensor<8xi1, "bar">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
794+
// expected-error @+1 {{'arith.select' op expected condition type to have the same shape as the result type}}
795+
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
796+
return %0 : tensor<8xi32, "foo">
797+
}

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,12 @@ func.func @test_extf_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf64> {
637637
return %0 : tensor<8x8xf64>
638638
}
639639

640+
// CHECK-LABEL: test_extf_tensor_encoding
641+
func.func @test_extf_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
642+
%0 = arith.extf %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
643+
return %0 : tensor<8x8xf64, "foo">
644+
}
645+
640646
// CHECK-LABEL: test_extf_vector
641647
func.func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
642648
%0 = arith.extf %arg0 : vector<8xf32> to vector<8xf64>
@@ -950,6 +956,12 @@ func.func @test_cmpi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
950956
return %0 : tensor<8x8xi1>
951957
}
952958

959+
// CHECK-LABEL: test_cmpi_tensor_encoding
960+
func.func @test_cmpi_tensor_encoding(%arg0 : tensor<8x8xi64, "foo">, %arg1 : tensor<8x8xi64, "foo">) -> tensor<8x8xi1, "foo"> {
961+
%0 = arith.cmpi slt, %arg0, %arg1 : tensor<8x8xi64, "foo">
962+
return %0 : tensor<8x8xi1, "foo">
963+
}
964+
953965
// CHECK-LABEL: test_cmpi_vector
954966
func.func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi1> {
955967
%0 = arith.cmpi ult, %arg0, %arg1 : vector<8xi64>
@@ -1103,3 +1115,18 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
11031115

11041116
return
11051117
}
1118+
1119+
// CHECK-LABEL: @select_tensor
1120+
func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
1121+
// CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
1122+
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1>, tensor<8xi32>
1123+
return %0 : tensor<8xi32>
1124+
}
1125+
1126+
// CHECK-LABEL: @select_tensor_encoding
1127+
func.func @select_tensor_encoding(
1128+
%arg0 : tensor<8xi1, "foo">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
1129+
// CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1, "foo">, tensor<8xi32, "foo">
1130+
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
1131+
return %0 : tensor<8xi32, "foo">
1132+
}

0 commit comments

Comments
 (0)