Skip to content

Commit 0000fe8

Browse files
committed
[MLIR] Add test cases for illegal output shape argument
Signed-Off-by: Gaurav Shukla<[email protected]>
1 parent 81a0637 commit 0000fe8

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -762,14 +762,14 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
762762
continue;
763763
bool foundDynamic = false;
764764
for (int64_t shape : expandedShape) {
765-
if (ShapedType::isDynamic(shape)) {
766-
if (foundDynamic) {
767-
return rewriter.notifyMatchFailure(
768-
linalgOp, "cannot infer expanded shape with multiple dynamic "
769-
"dims in the same reassociation group");
770-
}
771-
foundDynamic = true;
765+
if (!ShapedType::isDynamic(shape))
766+
continue;
767+
if (foundDynamic) {
768+
return rewriter.notifyMatchFailure(
769+
linalgOp, "cannot infer expanded shape with multiple dynamic "
770+
"dims in the same reassociation group");
772771
}
772+
foundDynamic = true;
773773
}
774774
}
775775
return success();

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,14 @@ func.func @expand_shape(%arg0: memref<f32>) {
408408

409409
// -----
410410

411+
func.func @expand_shape_illegal_output_shape(%arg0: memref<2xf32>) {
412+
// expected-error @+1 {{expected number of static shape bounds to be equal to the output rank (3) but found 2 inputs instead}}
413+
%0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 2] : memref<2xf32> into memref<1x1x2xf32>
414+
return
415+
}
416+
417+
// -----
418+
411419
func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
412420
// expected-error @+1 {{op reassociation index 2 is out of bounds}}
413421
%0 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<?x?xf32> into memref<?xf32>

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,15 @@ func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>, %sz
311311
return %0 : tensor<?x4x5xf32>
312312
}
313313

314+
// -----
315+
316+
func.func @expand_shape_illegal_output_shape(%arg0: tensor<2xf32>) {
317+
// expected-error @+1 {{expected number of static shape dims to be equal to the output rank (3) but found 2 inputs instead}}
318+
%0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 2] : tensor<2xf32> into tensor<1x1x2xf32>
319+
return
320+
}
321+
322+
314323
// -----
315324

316325
func.func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32> {

0 commit comments

Comments
 (0)