Skip to content

[mlir][vector] Use DenseI64ArrayAttr for shuffle masks #101163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def Vector_ShuffleOp :
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
I64ArrayAttr:$mask)>,
DenseI64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
Expand Down Expand Up @@ -459,11 +459,7 @@ def Vector_ShuffleOp :
: vector<f32>, vector<f32> ; yields vector<2xf32>
```
}];
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
];
let hasFolder = 1;
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
Expand All @@ -475,7 +471,10 @@ def Vector_ShuffleOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
}];

let assemblyFormat = "operands $mask attr-dict `:` type(operands)";

let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ class VectorShuffleOpConversion
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getResultVectorType();
Type llvmType = typeConverter->convertType(vectorType);
auto maskArrayAttr = shuffleOp.getMask();
ArrayRef<int64_t> mask = shuffleOp.getMask();

// Bail if result type cannot be lowered.
if (!llvmType)
Expand All @@ -1015,7 +1015,7 @@ class VectorShuffleOpConversion
if (rank <= 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.getV1(), adaptor.getV2(),
LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
llvm::to_vector_of<int32_t>(mask));
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
Expand All @@ -1029,8 +1029,7 @@ class VectorShuffleOpConversion
eltType = cast<VectorType>(llvmType).getElementType();
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
for (int64_t extPos : mask) {
Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,7 @@ struct VectorShuffleOpConvert final
return rewriter.notifyMatchFailure(shuffleOp,
"unsupported result vector type");

SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
shuffleOp.getMask(), [](Attribute attr) -> int32_t {
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});
auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());

VectorType oldV1Type = shuffleOp.getV1VectorType();
VectorType oldV2Type = shuffleOp.getV2VectorType();
Expand Down
42 changes: 17 additions & 25 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
Value v2, ArrayRef<int64_t> mask) {
build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
}

LogicalResult ShuffleOp::verify() {
VectorType resultType = getResultVectorType();
VectorType v1Type = getV1VectorType();
Expand All @@ -2491,19 +2486,18 @@ LogicalResult ShuffleOp::verify() {
return emitOpError("dimension mismatch");
}
// Verify mask length.
auto maskAttr = getMask().getValue();
int64_t maskLength = maskAttr.size();
ArrayRef<int64_t> mask = getMask();
int64_t maskLength = mask.size();
if (maskLength <= 0)
return emitOpError("invalid mask length");
if (maskLength != resultType.getDimSize(0))
return emitOpError("mask length mismatch");
// Verify all indices.
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (const auto &en : llvm::enumerate(maskAttr)) {
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
return emitOpError("mask index #") << (en.index() + 1) << " out of range";
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
if (maskPos < 0 || maskPos >= indexSize)
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
}
Expand All @@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
return success();
}

static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
uint64_t expected = begin;
return idxArr.size() == width &&
llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
[&expected](auto attr) {
return attr.getZExtValue() == expected++;
});
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
T expected = begin;
return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
return value == expected++;
});
}

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
SmallVector<Attribute> results;
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
int64_t i = index.getZExtValue();
for (int64_t i : this->getMask()) {
if (i >= lhsSize) {
results.push_back(rhsElements[i - lhsSize]);
} else {
Expand All @@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
VectorType v1VectorType = shuffleOp.getV1VectorType();
ArrayAttr mask = shuffleOp.getMask();
ArrayRef<int64_t> mask = shuffleOp.getMask();
if (v1VectorType.getRank() > 0)
return failure();
if (mask.size() != 1)
return failure();
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
if (mask[0] == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
else
Expand Down Expand Up @@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
op, "ShuffleOp types don't match an interleave");
}

ArrayAttr shuffleMask = op.getMask();
ArrayRef<int64_t> shuffleMask = op.getMask();
int64_t resultVectorSize = resultType.getNumElements();
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
int64_t maskValueA = shuffleMask[i * 2];
int64_t maskValueB = shuffleMask[(i * 2) + 1];
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
return rewriter.notifyMatchFailure(op,
"ShuffleOp mask not interleaving");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle
off += stride)
offsets.push_back(off);
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
op.getVector(),
rewriter.getI64ArrayAttr(offsets));
op.getVector(), offsets);
return success();
}
};
Expand Down
22 changes: 8 additions & 14 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
}
// Perform a shuffle to extract the kD vector.
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstType, srcVector, srcVector,
rewriter.getI64ArrayAttr(indices));
extractOp, dstType, srcVector, srcVector, indices);
return success();
}

Expand Down Expand Up @@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
ArrayAttr mask = shuffleOp.getMask();
ArrayRef<int64_t> mask = shuffleOp.getMask();
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
for (auto [i, value] :
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {

int64_t v = value.getZExtValue();
for (auto [i, value] : llvm::enumerate(mask)) {
std::iota(indices.begin() + shuffleSliceLen * i,
indices.begin() + shuffleSliceLen * (i + 1),
shuffleSliceLen * v);
shuffleSliceLen * value);
}

rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
vec2, indices);
return success();
}

Expand Down Expand Up @@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
rewriter.getI64ArrayAttr(indices));
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);

return success();
}
Expand Down Expand Up @@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)

rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
rewriter.getI64ArrayAttr(indices));
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

return success();
}
Expand Down
Loading