Skip to content

Commit 8cc0d41

Browse files
[mlir][linalg] Fix builder API usage in RegionBuilderHelper
Operations must be created with the supplied builder. Otherwise, the dialect conversion / greedy pattern rewrite driver can break. This commit fixes a crash in the dialect conversion: ``` within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: error: failed to legalize operation 'tosa.add' %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32> ^ within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: note: see current operation: %9 = "tosa.add"(%8, %arg2) : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32> mlir-opt: llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed. ``` This commit is the proper fix for #87297 (which was reverted).
1 parent 01e2274 commit 8cc0d41

File tree

4 files changed

+32
-20
lines changed

4 files changed

+32
-20
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,15 @@ namespace {
373373

374374
class RegionBuilderHelper {
375375
public:
376-
RegionBuilderHelper(MLIRContext *context, Block &block)
377-
: context(context), block(block) {}
376+
RegionBuilderHelper(OpBuilder &builder, Block &block)
377+
: builder(builder), block(block) {}
378378

379379
// Build the unary functions defined by OpDSL.
380380
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
381381
if (!isFloatingPoint(arg))
382382
llvm_unreachable("unsupported non numeric type");
383-
OpBuilder builder = getBuilder();
383+
OpBuilder::InsertionGuard g(builder);
384+
builder.setInsertionPointToEnd(&block);
384385
switch (unaryFn) {
385386
case UnaryFn::exp:
386387
return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ class RegionBuilderHelper {
407408
arg1.getType().getIntOrFloatBitWidth() == 1;
408409
if (!allComplex && !allFloatingPoint && !allInteger)
409410
llvm_unreachable("unsupported non numeric type");
410-
OpBuilder builder = getBuilder();
411+
OpBuilder::InsertionGuard g(builder);
412+
builder.setInsertionPointToEnd(&block);
411413
switch (binaryFn) {
412414
case BinaryFn::add:
413415
if (allComplex)
@@ -481,37 +483,41 @@ class RegionBuilderHelper {
481483
}
482484

483485
void yieldOutputs(ValueRange values) {
484-
OpBuilder builder = getBuilder();
486+
OpBuilder::InsertionGuard g(builder);
487+
builder.setInsertionPointToEnd(&block);
485488
Location loc = builder.getUnknownLoc();
486489
builder.create<YieldOp>(loc, values);
487490
}
488491

489492
Value constant(const std::string &value) {
490-
OpBuilder builder = getBuilder();
493+
OpBuilder::InsertionGuard g(builder);
494+
builder.setInsertionPointToEnd(&block);
491495
Location loc = builder.getUnknownLoc();
492496
Attribute valueAttr = parseAttribute(value, builder.getContext());
493497
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
494498
}
495499

496500
Value index(int64_t dim) {
497-
OpBuilder builder = getBuilder();
501+
OpBuilder::InsertionGuard g(builder);
502+
builder.setInsertionPointToEnd(&block);
498503
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
499504
}
500505

501506
Type getIntegerType(unsigned width) {
502-
return IntegerType::get(context, width);
507+
return IntegerType::get(builder.getContext(), width);
503508
}
504509

505-
Type getFloat32Type() { return Float32Type::get(context); }
506-
Type getFloat64Type() { return Float64Type::get(context); }
510+
Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
511+
Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
507512

508513
private:
509514
// Generates operations to cast the given operand to a specified type.
510515
// If the cast cannot be performed, a warning will be issued and the
511516
// operand returned as-is (which will presumably yield a verification
512517
// issue downstream).
513518
Value cast(Type toType, Value operand, bool isUnsignedCast) {
514-
OpBuilder builder = getBuilder();
519+
OpBuilder::InsertionGuard g(builder);
520+
builder.setInsertionPointToEnd(&block);
515521
auto loc = operand.getLoc();
516522
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
517523
}
@@ -526,13 +532,7 @@ class RegionBuilderHelper {
526532
return llvm::isa<IntegerType>(value.getType());
527533
}
528534

529-
OpBuilder getBuilder() {
530-
OpBuilder builder(context);
531-
builder.setInsertionPointToEnd(&block);
532-
return builder;
533-
}
534-
535-
MLIRContext *context;
535+
OpBuilder &builder;
536536
Block &block;
537537
};
538538

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ struct SparseTensorCodegenPass
274274
});
275275
// The following operations and dialects may be introduced by the
276276
// codegen rules, and are therefore marked as legal.
277-
target.addLegalOp<linalg::FillOp>();
277+
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
278278
target.addLegalDialect<
279279
arith::ArithDialect, bufferization::BufferizationDialect,
280280
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,15 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
1515
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
1616
return %0 : tensor<*xi8>
1717
}
18+
19+
// -----
20+
21+
// CHECK-LABEL: @unranked_add
22+
func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %arg2 : tensor<*xf32>) -> (tensor<10x10xf32>) {
23+
// expected-error@+3 {{failed to legalize operation 'tosa.add'}}
24+
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
25+
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
26+
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
27+
%2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
28+
return %2 : tensor<10x10xf32>
29+
}

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
10081008
Block &block, ArrayRef<NamedAttribute> attrs) {{
10091009
assert({1} > 0 && block.getNumArguments() == {1} &&
10101010
"{0} regionBuilder expects {1} (>=0) args");
1011-
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
1011+
RegionBuilderHelper helper(b, block);
10121012
SmallVector<Value> yields;
10131013
{2}
10141014
{3}

0 commit comments

Comments
 (0)