Commit 53d435d

[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC)
This patch removes `inputVecSizesForLeadingDims` from the parameter list of
`createWriteOrMaskedWrite`. That argument is unnecessary: the vector sizes can
be obtained from the `vecToStore` parameter. Since this doesn't change behavior
or test results, it's marked as NFC.

Additional cleanups:

* Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
* Rewrote the conditional at the end of the function to use an early exit,
  improving readability:

```cpp
// BEFORE:
if (maskingRequired) {
  Value maskForWrite = ...;
  write = maskOperation(write, maskForWrite);
}
return write;

// AFTER:
if (!maskingRequired)
  return write;
Value maskForWrite = ...;
return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
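To make the masking path concrete, here is a minimal sketch of the IR this helper produces when shape(vecToStore) != shape(dest), as described in the updated doc comment. The function name, value names, and shapes (`vector<2x4xf32>` into `tensor<?x?xf32>`) are hypothetical, chosen only for illustration:

```mlir
func.func @masked_write_sketch(%vec: vector<2x4xf32>,
                               %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // The mask shape matches the vector to store (element type i1); the mask
  // values are driven by the destination's runtime sizes.
  %d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
        : vector<2x4xf32>, tensor<?x?xf32>
  } : vector<2x4xi1> -> tensor<?x?xf32>
  return %res : tensor<?x?xf32>
}
```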
1 parent f0922e9 commit 53d435d

1 file changed (+40 −78 lines)
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (40 additions, 78 deletions)

```diff
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///    %res = vector.transfer_write %vectorToStore into %dest
+///    %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///    %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///    %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///    %res = vector.mask %mask {
-///      vector.transfer_write %vectorToStore into %dest
+///      vector.transfer_write %vecToStore into %dest
 ///    }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///    %write = vector.transfer_write %vectorToStore into %dest
+///    %write = vector.transfer_write %vecToStore into %dest
 ///      in_bounds_flags = (...)
 ///    %res = vector.transfer_write %input into %dest
 ///      {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
      transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
```
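For comparison, a sketch of the unmasked path: when `useInBoundsInsteadOfMasking` is set, the helper emits a plain `vector.transfer_write` and records which dims are statically known to fit via `in_bounds`, per the `inBoundsVal` computation above. The function name and shapes (`vector<2x4xf32>` into `tensor<2x?xf32>`) are hypothetical:

```mlir
func.func @in_bounds_write_sketch(%vec: vector<2x4xf32>,
                                  %dest: tensor<2x?xf32>) -> tensor<2x?xf32> {
  %c0 = arith.constant 0 : index
  // Dim 0 matches statically (2 == 2), so it is marked in-bounds; dim 1 is
  // dynamic, so it is conservatively marked out-of-bounds.
  %res = vector.transfer_write %vec, %dest[%c0, %c0]
      {in_bounds = [true, false]} : vector<2x4xf32>, tensor<2x?xf32>
  return %res : tensor<2x?xf32>
}
```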
