|
15 | 15 | #include "mlir/Dialect/SCF/SCF.h"
|
16 | 16 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
17 | 17 | #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
18 |
| -#include "mlir/Dialect/Tensor/IR/Tensor.h" |
19 | 18 | #include "mlir/IR/BlockAndValueMapping.h"
|
20 | 19 | #include "mlir/Transforms/DialectConversion.h"
|
21 | 20 |
|
@@ -70,18 +69,29 @@ class BufferizeDynamicTensorFromElementsOp
|
70 | 69 | upperBounds.push_back(upperBound);
|
71 | 70 | }
|
72 | 71 |
|
73 |
| - // Generate tensor elements with a parallel loop. |
74 |
| - rewriter.create<scf::ParallelOp>( |
75 |
| - loc, lowerBounds, upperBounds, steps, |
76 |
| - [&](OpBuilder &b, Location loc, ValueRange ivs) { |
77 |
| - BlockAndValueMapping mapping; |
78 |
| - mapping.map(op.body().getArguments(), ivs); |
79 |
| - for (auto &nestedOp : op.getBody()->without_terminator()) |
80 |
| - b.clone(nestedOp, mapping); |
81 |
| - auto yieldOp = cast<YieldOp>(op.getBody()->getTerminator()); |
82 |
| - b.create<StoreOp>(loc, mapping.lookup(yieldOp.value()), result, ivs); |
83 |
| - b.create<scf::YieldOp>(loc); |
84 |
| - }); |
| 72 | + // Generate tensor elements with a parallel loop that stores into |
| 73 | + // each element of the resulting memref. |
| 74 | + // |
| 75 | + // This is a bit tricky. We cannot simply clone the ops because when an op |
| 76 | + // is cloned, it must be legalized. However, we want to allow arbitrary ops |
| 77 | + // in the body that we don't necessarily have legalization patterns for as |
| 78 | + // part of this dialect conversion invocation. |
| 79 | + // |
| 80 | + // To accomplish this, we use mergeBlockBefore to "move" this op's body |
| 81 | + // into the scf.parallel's body. |
| 82 | + auto parallel = |
| 83 | + rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); |
| 84 | + Block *parallelBody = parallel.getBody(); |
| 85 | + rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), |
| 86 | + parallelBody->getArguments()); |
| 87 | + // Replace the inlined yield op with a store op. The scf.parallel's builder |
| 88 | + // already populated an scf.yield at the end, so we don't need to worry |
| 89 | + // about creating that. |
| 90 | + Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); |
| 91 | + rewriter.setInsertionPointAfter(elementYield); |
| 92 | + rewriter.replaceOpWithNewOp<StoreOp>(elementYield, |
| 93 | + elementYield->getOperands()[0], result, |
| 94 | + parallelBody->getArguments()); |
85 | 95 |
|
86 | 96 | rewriter.replaceOp(op, {result});
|
87 | 97 | return success();
|
@@ -168,7 +178,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
|
168 | 178 |
|
169 | 179 | target.addLegalDialect<StandardOpsDialect>();
|
170 | 180 | target.addLegalDialect<scf::SCFDialect>();
|
171 |
| - target.addLegalDialect<tensor::TensorDialect>(); |
172 | 181 |
|
173 | 182 | populateStdBufferizePatterns(context, typeConverter, patterns);
|
174 | 183 | target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
|
|
0 commit comments