
[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes #139706

Merged
171 changes: 114 additions & 57 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
}
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to a
/// legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we've reached this pass and still have
/// shape_casts of illegal types, they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
struct ConvertIllegalShapeCastOpsToTransposes
: public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
auto sourceType = shapeCastOp.getSourceVectorType();
auto resultType = shapeCastOp.getResultVectorType();
if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
return rewriter.notifyMatchFailure(shapeCastOp,
kMatchFailureNotIllegalToLegal);

// Note: If we know that `sourceType` is an illegal vector type (and 2D)
// then dim 0 is scalable and dim 1 is fixed.
if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
return rewriter.notifyMatchFailure(
shapeCastOp, "expected source to be a 2D scalable vector with a "
"trailing unit dim");

auto loc = shapeCastOp.getLoc();
auto transpose = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

if (resultType.getRank() == 1)
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
transpose);
else
rewriter.replaceOp(shapeCastOp, transpose);

return success();
}
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This is a workaround rewrite to support these transposes when
/// ZA is available.
@@ -920,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
}
};

/// Lower `vector.transfer_read` of a scalable column to `scf::for`
///
/// Lowers a "read" of a scalable column from a MemRef, for which there is no
/// hardware operation that we could use, to a loop over the rows that loads
/// one element at a time.
///
/// BEFORE:
/// ```
/// %res = vector.transfer_read %mem[%a, %b] (...)
/// : memref<?x?xf32>, vector<[4]x1xf32>
/// ```
///
/// AFTER:
/// ```
/// %cst = arith.constant (...) : vector<[4]xf32>
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %scf = scf.for %i = %c0 to %c4_vscale step %c1 iter_args(%acc = %cst)
///   -> (vector<[4]xf32>) {
///   %idx = arith.addi %i, %a : index
///   %load = memref.load %mem[%idx, %b] : memref<?x?xf32>
///   %vec = vector.insert %load, %acc [%i] : f32 into vector<[4]xf32>
///   scf.yield %vec : vector<[4]xf32>
/// }
/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
/// ```
///
/// TODO: This transformation isn't specific to SME - move it to the SVE
/// dialect.
/// TODO: Check the in_bounds attribute and generate vector.maskedload if
/// required.
struct LowerColumnTransferReadToLoops
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
// NOTE: This is a fairly low-level transformation, so we shouldn't be
// adding support for Tensors without good rationale.
if (readOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
readOp, "Tensor semantics are unsupported (either bufferize or "
"extend this pattern)");

auto resType = readOp.getVectorType();

if (resType.getRank() != 2)
return rewriter.notifyMatchFailure(readOp,
"Only 2D vectors are supported!");

if (resType.getShape()[1] != 1)
return rewriter.notifyMatchFailure(
readOp, "The trailing output dim is != 1 (not supported ATM)");

if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
return rewriter.notifyMatchFailure(
readOp, "Expected the leading dim to be scalable and the trailing "
"dim to be fixed.");

// Create new result type - similar to the original vector with the
// trailing unit dim collapsed.
int64_t numRows = resType.getShape()[0];
VectorType newResType = VectorType::get(numRows, resType.getElementType(),
/*scalableDims=*/{true});

// Create a loop over all rows and load one element at a time.
auto loc = readOp.getLoc();
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto createVscaleMultiple =
vector::makeVscaleConstantBuilder(rewriter, loc);
auto upperBound = createVscaleMultiple(numRows);
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value init = rewriter.create<arith::ConstantOp>(
loc, newResType, DenseElementsAttr::get(newResType, 0.0f));

scf::ForOp loadLoop;
{
OpBuilder::InsertionGuard g(rewriter);
loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
ValueRange{init});
rewriter.setInsertionPointToStart(loadLoop.getBody());

auto tileSliceIndex = loadLoop.getInductionVar();

auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
readOp.getIndices()[0]);
auto idx1 = readOp.getIndices()[1];

Value scalar = rewriter.create<memref::LoadOp>(
loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));

Operation *updateInit = rewriter.create<vector::InsertOp>(
loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);

rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
}

// The read operation has been "legalized", but since the original result
// type was a 2D vector, we need to cast before returning the result. This
// ShapeCast should cancel out with some other ShapeCast (i.e. it's a
// no-op).
auto sc = rewriter.create<vector::ShapeCastOp>(
loc, readOp.getResult().getType(), loadLoop.getResult(0));

rewriter.replaceOp(readOp, sc);

return success();
}
};

struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -941,10 +998,10 @@

// Apply preprocessing patterns.
RewritePatternSet rewritePatterns(context);
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
rewritePatterns
.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
return signalPassFailure();
13 changes: 1 addition & 12 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5758,18 +5758,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
// This folder does
// shape_cast(transpose) -> shape_cast
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
// shape_cast -> shape_cast(transpose)
// i.e. the complete opposite. When paired, these 2 patterns can cause
// infinite cycles in pattern rewriting.
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
// vectors, so by disabling this folder for scalable vectors the
// cycle is avoided.
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
// still needed. If it's not, then we can fold here.
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
if (isOrderPreserving(transpose)) {
setOperand(transpose.getVector());
return getResult();
}
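With the scalable-vector guard removed, the shape_cast(transpose(x)) -> shape_cast(x) folder also applies to scalable vectors. Below is a minimal hand-written sketch of IR this fold can now simplify; the types are chosen for illustration and are not a test from this PR:

```mlir
// A transpose that only moves a unit dim (so it is order preserving)...
%t = vector.transpose %v, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%flat = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>

// ...folds to a single shape_cast of the transpose's operand:
%flat_folded = vector.shape_cast %v : vector<[4]x1xf32> to vector<[4]xf32>
```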
101 changes: 56 additions & 45 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v

// -----

// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
// CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}

// -----

// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// -----

// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
return %cast : vector<1x[4]xf32>
Comment on lines -518 to -522

Member: I'm not sure what you've tested, but to know if this rewrite is still needed or not, this test case should still be possible to lower to LLVM.

Contributor Author: Thanks Ben!

> I'm not sure what you've tested

I used our e2e tests - from what I can tell, we don't generate such code anymore.

> this test case should still be possible to lower to LLVM

Indeed. @momchil-velikov, since you are working on a generic pattern for "xfer_read with non-trailing scalable dims", could you make sure that this example lowers with your patch?

  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I will wait for Momchil to upload his patch before progressing this one.

Contributor: I can see 2 PRs of @momchil-velikov in llvm-project that might be related, but just checking in that this PR is still on the radar.

Collaborator: Yes, it's on the radar and hasn't slipped through the cracks.

Collaborator:

> %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memref (?) and the vector (1) do not match and the read cannot be inferred to be contiguous, e.g. if we're reading from a memref with dynamic dimensions 4 and 2:

  [*][]
  [*][]
  [*][]
  [*][]

Collaborator: This one

  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}
    : memref<?x2xf32>, vector<[4]x2xf32>

is lowered, though (with an implication that %b is zero, along the way).

Member:

> I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memref

This lowering does not depend on the memref/transpose being contiguous. It lowers the transfer_read to a memref.transpose + transfer_read, which lowers to a loop in the case of a non-contiguous read (such as in this test case): https://godbolt.org/z/b96E7aYq4
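For reference, a hand-written sketch of the memref.transpose + transfer_read shape being described; the layout, index handling, and in_bounds details below are illustrative assumptions, not output copied from the PR or the godbolt link:

```mlir
// An illegal column read followed by a transpose (as in the deleted test)...
%pad = arith.constant 0.0 : f32
%col = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}
    : memref<?x?xf32>, vector<[4]x1xf32>
%row = vector.transpose %col, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>

// ...can instead be expressed as a legal read of a transposed view of the
// memref (result layout shown approximately); a non-contiguous read of such
// a view is then lowered to a loop by the usual transfer_read lowerings.
%view = memref.transpose %memref (d0, d1) -> (d1, d0)
    : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
%legal = vector.transfer_read %view[%b, %a], %pad {in_bounds = [true, false]}
    : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>, vector<1x[4]xf32>
```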

Contributor Author: Sorry, I've not had a chance to return to this yet. It's one of 3 things that are "next" on my list :)
}

// -----

// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
return %cast : vector<[4]xf32>
}

// -----

// CHECK-LABEL: @multi_tile_splat
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
{
@@ -656,3 +611,59 @@ func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<
%0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
return %0 : vector<16x16xf32>
}

// -----

//=============================================================================
// 1D examples - to be moved to the SVE dialect
//=============================================================================

/// TODO: Handle in_bounds

// CHECK-LABEL: func.func @xfer_read_scalable_column(
// CHECK-SAME: %[[IDX_0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[PAD:.*]]: f32,
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>) -> vector<[4]x1xf32> {
func.func @xfer_read_scalable_column(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x1xf32>) {
// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[LB:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index

// <scf.for>
// CHECK: %[[SCF:.*]] = scf.for %[[IND_VAR:.*]] = %[[LB]] to %[[C4_VSCALE]] step %[[STEP]] iter_args(%[[SCF_RES:.*]] = %[[INIT]]) -> (vector<[4]xf32>) {
// CHECK: %[[IDX_0_UPDATED:.*]] = arith.addi %[[IND_VAR]], %[[IDX_0]] : index
// CHECK: %[[VAL_10:.*]] = memref.load %[[SRC]][%[[IDX_0_UPDATED]], %[[IDX_1]]] : memref<?x?xf32>
// CHECK: %[[RES_UPDATED:.*]] = vector.insert %[[VAL_10]], %[[SCF_RES]] [%[[IND_VAR]]] : f32 into vector<[4]xf32>
// CHECK: scf.yield %[[RES_UPDATED]] : vector<[4]xf32>
// CHECK: }

// <shape-cast>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[SCF]] : vector<[4]xf32> to vector<[4]x1xf32>
// CHECK: return %[[SC]]
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
return %read : vector<[4]x1xf32>
}

// -----

// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_x2
func.func @negative_xfer_read_scalable_column_x2(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x2xf32>) {
// CHECK-NOT: scf.for
// CHECK-NOT: memref.load
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x2xf32>
return %read : vector<[4]x2xf32>
}

// -----

// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_scalable_trailing_dim
func.func @negative_xfer_read_scalable_column_scalable_trailing_dim(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<4x[1]xf32>) {
// CHECK-NOT: scf.for
// CHECK-NOT: memref.load
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<4x[1]xf32>
return %read : vector<4x[1]xf32>
}