|
22 | 22 | using namespace mlir;
|
23 | 23 | using namespace mlir::tensor;
|
24 | 24 |
|
25 |
| -PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source, |
| 25 | +PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source, |
26 | 26 | Value pad, bool nofold, Location loc,
|
27 | 27 | OpBuilder &b,
|
28 |
| - std::optional<Value> dynOutDim) { |
| 28 | + SmallVector<Value> dynOutDims) { |
29 | 29 |
|
30 |
| - assert(type.getNumDynamicDims() <= 1 && |
31 |
| - "At most one output dim can be dynamic!"); |
| 30 | + assert((resType.getNumDynamicDims() == dynOutDims.size()) || |
| 31 | + dynOutDims.empty() && |
| 32 | + "Either none or all output dynamic dims must be specified!"); |
32 | 33 |
|
33 | 34 | // Init "low" and "high" padding values ("low" is kept as is, "high" is
|
34 | 35 | // computed below).
|
35 |
| - SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0)); |
36 |
| - SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0)); |
| 36 | + SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0)); |
| 37 | + SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0)); |
37 | 38 |
|
38 |
| - for (const auto [idx, val] : enumerate(type.getShape())) { |
39 |
| - bool isOutDimDynamic = ShapedType::isDynamic(val); |
40 |
| - assert((!isOutDimDynamic || dynOutDim.has_value()) && |
41 |
| - "dynamic output dim requires dynOutDim to be set"); |
| 39 | + size_t outDimIdx = 0; |
42 | 40 |
|
43 |
| - // Compute the padding width: outDim - srcDim. |
| 41 | + for (const auto [idx, val] : enumerate(resType.getShape())) { |
| 42 | + bool isDimDynamic = ShapedType::isDynamic(val); |
| 43 | + bool updatePadHigh = !isDimDynamic || !dynOutDims.empty(); |
| 44 | + |
| 45 | + // Keep the default padding width (i.e. "0") when the output dim is dynamic |
| 46 | + // and no actual output sizes have been provided. |
| 47 | + if (!updatePadHigh) |
| 48 | + continue; |
| 49 | + |
| 50 | + // Compute the padding width: resDim - sourceDim. |
44 | 51 | AffineExpr d0, d1;
|
45 | 52 | bindDims(b.getContext(), d0, d1);
|
46 |
| - OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx); |
47 |
| - Value outDim = isOutDimDynamic |
48 |
| - ? dynOutDim.value() |
| 53 | + OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx); |
| 54 | + Value outDim = isDimDynamic |
| 55 | + ? dynOutDims[outDimIdx++] |
49 | 56 | : b.create<arith::ConstantIndexOp>(loc, val).getResult();
|
50 | 57 |
|
51 | 58 | high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
|
52 |
| - {outDim, srcDim}); |
| 59 | + {outDim, sourceDim}); |
53 | 60 | }
|
54 |
| - return b.create<PadOp>(loc, type, source, low, high, pad, nofold); |
| 61 | + return b.create<PadOp>(loc, resType, source, low, high, pad, nofold); |
55 | 62 | }
|
56 | 63 |
|
57 | 64 | SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
|
|
0 commit comments