[mlir][tensor] Extend the logic to generalise tensor.pack #109815


Merged: 6 commits, Oct 2, 2024
22 changes: 16 additions & 6 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -14,12 +14,22 @@
namespace mlir {
namespace tensor {

// Return a PadOp that pads `source` to `type` size where the static
// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
// dimensions the padding width is set to zero. The op performs "high" padding
// (i.e. it adds trailing padding values until the desired size is met).
PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
bool nofold, Location loc, OpBuilder &builder);
// Return a PadOp that pads `source` to `resType` size. The op performs "high"
// padding, i.e. it adds trailing padding values until the desired size is met.
// Output sizes are assumed to be greater than the input sizes. The padding
// width is calculated as: resDim - sourceDim.
//
// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
// 1. dynamic input sizes are extracted from `source`
// 2. for dynamic output dims, there are two options:
// 2.1 all output dynamic dim sizes are specified in `dynOutDim`,
// 2.2 `dynOutDim` is empty and the corresponding padding width is set to 0.
//
// (*) Note that `resType` is just a shape and it only encodes the actual sizes
// for _static_ dimensions.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
bool nofold, Location loc, OpBuilder &builder,
SmallVector<Value> dynOutDim = {});
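
For illustration, a minimal usage sketch of the extended helper (not part of this patch; the builder, location and the SSA values `source`, `padValue` and `tileSize` are assumed to already exist). It pads a tensor<5x1xf32> source up to tensor<?x2xf32>, passing the dynamic output size through `dynOutDim`:

#include "mlir/Dialect/Tensor/Utils/Utils.h"

using namespace mlir;

// Sketch only: pad `source` (tensor<5x1xf32>) to tensor<?x2xf32>, where the
// dynamic size of dim 0 is carried by `tileSize`.
static tensor::PadOp padToDynamicTile(OpBuilder &builder, Location loc,
                                      Value source, Value padValue,
                                      Value tileSize) {
  auto resType =
      RankedTensorType::get({ShapedType::kDynamic, 2}, builder.getF32Type());
  // One Value per dynamic dimension of `resType`, in dimension order. Passing
  // an empty vector instead would leave the dynamic dim with zero padding.
  SmallVector<Value> dynOutDim{tileSize};
  return tensor::createPadHighOp(resType, source, padValue,
                                 /*nofold=*/false, loc, builder, dynOutDim);
}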

// Creates dim ops for each dynamic dimension of the ranked tensor argument and
// returns these as values.
82 changes: 56 additions & 26 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return success();
}

/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
/// source directly. The method assumes that the `packOp` has static shapes.
/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
tensor::PackOp packOp) {
Value input = packOp.getSource();
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();

SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
packOp.getDimAndTileMapping();
for (int64_t dim = 0; dim < inputRank; ++dim) {
int64_t size = inputType.getDimSize(dim);
if (!tileAndPosMapping.count(dim)) {
paddedShape.push_back(size);

// The sizes of dynamic tiles
SmallVector<Value> dynamicTileSizes;

// Collect dims for the padded shape.
SmallVector<int64_t> paddedShape;
for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
// 1. Non-tiled outer dims.
// These dims should be 1 and we simply preserve them.
if (!tileAndPosMapping.count(dimIdx)) {
int64_t inputDimSize = inputType.getDimSize(dimIdx);
assert(inputDimSize == 1 &&
"with all outer dims == 1, this non-tiled input dim should be 1!");
paddedShape.push_back(inputDimSize);
continue;
}

// 2. Tiled outer dims
// As all outer dims == 1, it is safe to use the tile size for the padded
// shape.
OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

// 2.1 Static tile sizes
std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
if (cstTileSize.has_value()) {
paddedShape.push_back(cstTileSize.value());
continue;
}

// The size is less than or equal to tileSize because outer dims are all 1s.
std::optional<int64_t> tileSize =
getConstantIntValue(tileAndPosMapping.lookup(dim));
assert(tileSize.has_value() && "dynamic inner tile size is not supported");
paddedShape.push_back(tileSize.value());
// 2.2 Dynamic tile sizes
paddedShape.push_back(ShapedType::kDynamic);

// Get the value that holds the dynamic size.
dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
}
auto resultType =
RankedTensorType::get(paddedShape, inputType.getElementType());
return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
/*nofold=*/false, loc, builder);
/*nofold=*/false, loc, builder,
dynamicTileSizes);
}
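
Worked example for the dynamic-tile case added in the tests below: with a source of type tensor<5x1xf32>, inner_dims_pos = [0, 1] and inner_tiles = [%high, 2], both source dims are tiled, so the loop produces paddedShape = [kDynamic, 2] and dynamicTileSizes = [%high]. The padded source therefore has type tensor<?x2xf32>, which is the `resultType` handed to createPadHighOp together with the dynamic tile size.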

// Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
if (llvm::any_of(packOp.getMixedTiles(),
[](OpFoldResult tile) { return tile.is<Value>(); })) {
return rewriter.notifyMatchFailure(packOp,
"require inner tile sizes being static");
if (llvm::count_if(packOp.getMixedTiles(),
[](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
return rewriter.notifyMatchFailure(
packOp, "at most one dynamic tile size is supported");
}

// TODO: support the case that outer dimensions are not all 1s. A
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<int64_t> readShape;
SmallVector<OpFoldResult> transShapeForEmpty;
SmallVector<int64_t> readShapeForExtractSlice;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
.value_or(ShapedType::kDynamic));
readShapeForExtractSlice.push_back(
getConstantIntValue(dimAndTileMapping[i])
.value_or(ShapedType::kDynamic));
readSizes.push_back(dimAndTileMapping[i]);
transShapeForEmpty.push_back(dimAndTileMapping[i]);
continue;
}
if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
} else {
readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
}
if (inputShape[i] != 1)
readShape.push_back(inputShape[i]);
if (inputShape[i] != 1) {
readShapeForExtractSlice.push_back(inputShape[i]);
transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
}
}

Type elemType = packOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShape, elemType);
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);

Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

SmallVector<int64_t> transpShape = readShape;
applyPermutationToVector<int64_t>(transpShape, perm);
applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);

Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
Value empty =
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

48 changes: 34 additions & 14 deletions mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,28 +16,48 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR//VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
Value pad, bool nofold, Location loc,
OpBuilder &b) {
SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
for (const auto &en : enumerate(type.getShape())) {
// Pad only the static dimensions of the result tensor type.
if (ShapedType::isDynamic(en.value()))
OpBuilder &b,
SmallVector<Value> dynOutDims) {

assert((resType.getNumDynamicDims() == dynOutDims.size()) ||
dynOutDims.empty() &&
"Either none or all output dynamic dims must be specified!");

// Init "low" and "high" padding values ("low" is kept as is, "high" is
// computed below).
SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));

size_t outDimIdx = 0;

for (const auto [idx, val] : enumerate(resType.getShape())) {
bool isDimDynamic = ShapedType::isDynamic(val);
bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();

// Keep the default padding width (i.e. "0") when the output dim is dynamic
// and no actual output sizes have been provided.
if (!updatePadHigh)
continue;
// Compute the padding width.
AffineExpr d0;
bindDims(b.getContext(), d0);
OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
high[en.index()] =
affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});

// Compute the padding width: resDim - sourceDim.
AffineExpr d0, d1;
bindDims(b.getContext(), d0, d1);
OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
: OpFoldResult(b.getIndexAttr(val));

high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
{outDim, sourceDim});
}
return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
}
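
Worked example, matching the new simple_pad_and_pack_dynamic test below: for source : tensor<5x1xf32>, resType = tensor<?x2xf32> and dynOutDims = {%high}, the loop computes
  high[0] = %high - 5, materialized as an affine.apply of affine_map<()[s0] -> (s0 - 5)> on %high (the map captured as #[[$ATTR_0]] in the tests), and
  high[1] = 2 - 1, which folds to the constant 1,
so the helper emits tensor.pad %source low[0, 0] high[%pad_high, 1] (SSA names illustrative).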

SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
}
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>

// CHECK-LABEL: func.func @simple_pad_and_pack
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,59 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

/// Same as the example above, but with a dynamic tile size.

func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
}

// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>

/// Same as the example above, but with a scalable tile size.

/// NOTE: For this example to make sense in practice, the "?" in the output shape
/// should effectively be 8 * vector.vscale (and that's what tensor.dim
/// below should return).

func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
}

// CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[VS:.+]] = vector.vscale
// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>

// -----

func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{