Skip to content

[mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended #73340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down Expand Up @@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down Expand Up @@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];

let hasCanonicalizer = 1;
}

// -----
Expand Down
194 changes: 194 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,200 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// We are required to use CompositeConstructOp to create a constant struct as
// they are not yet implemented as constant, hence we can not do so in a fold.
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
Type constituentType = lhs.getType();

// iaddcarry (x, 0) = <0, x>
if (matchPattern(rhs, m_Zero())) {
Value constituents[2] = {rhs, lhs};
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
constituents);
return success();
}

// According to the SPIR-V spec:
//
// Result Type must be from OpTypeStruct. The struct must have two
// members...
//
// Member 0 of the result gets the low-order bits (full component width) of
// the addition.
//
// Member 1 of the result gets the high-order (carry) bit of the result of
// the addition. That is, it gets the value 1 if the addition overflowed
// the component width, and 0 otherwise.
Attribute lhsAttr;
Attribute rhsAttr;
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
!matchPattern(rhs, m_Constant(&rhsAttr)))
return failure();

auto adds = constFoldBinaryOp<IntegerAttr>(
{lhsAttr, rhsAttr},
[](const APInt &a, const APInt &b) { return a + b; });
if (!adds)
return failure();

auto carrys = constFoldBinaryOp<IntegerAttr>(
ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
APInt zero = APInt::getZero(a.getBitWidth());
return a.ult(b) ? (zero + 1) : zero;
});

if (!carrys)
return failure();

Value addsVal =
rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);

Value carrysVal =
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);

// Create empty struct
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
// Fill in adds at id 0
Value intermediate =
rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
// Fill in carrys at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
intermediate, 1);
return success();
}
};

void spirv::IAddCarryOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<IAddCarryFold>(context);
}

//===----------------------------------------------------------------------===//
// spirv.[S|U]MulExtended
//===----------------------------------------------------------------------===//

// We are required to use CompositeConstructOp to create a constant struct as
// they are not yet implemented as constant, hence we can not do so in a fold.
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
using OpRewritePattern<MulOp>::OpRewritePattern;

LogicalResult matchAndRewrite(MulOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
Type constituentType = lhs.getType();

// [su]mulextended (x, 0) = <0, 0>
if (matchPattern(rhs, m_Zero())) {
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
Value constituents[2] = {zero, zero};
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
constituents);
return success();
}

// According to the SPIR-V spec:
//
// Result Type must be from OpTypeStruct. The struct must have two
// members...
//
// Member 0 of the result gets the low-order bits of the multiplication.
//
// Member 1 of the result gets the high-order bits of the multiplication.
Attribute lhsAttr;
Attribute rhsAttr;
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
!matchPattern(rhs, m_Constant(&rhsAttr)))
return failure();

auto lowBits = constFoldBinaryOp<IntegerAttr>(
{lhsAttr, rhsAttr},
[](const APInt &a, const APInt &b) { return a * b; });

if (!lowBits)
return failure();

auto highBits = constFoldBinaryOp<IntegerAttr>(
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt c;
if (IsSigned) {
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
} else {
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
}
return c.extractBits(bitWidth, bitWidth); // Extract high result
});

if (!highBits)
return failure();

Value lowBitsVal =
rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);

Value highBitsVal =
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);

// Create empty struct
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
// Fill in lowBits at id 0
Value intermediate =
rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
// Fill in highBits at id 1
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
intermediate, 1);
return success();
}
};

using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<SMulExtendedOpFold>(context);
}

struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
Type constituentType = lhs.getType();

// umulextended (x, 1) = <x, 0>
if (matchPattern(rhs, m_One())) {
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
Value constituents[2] = {lhs, zero};
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
constituents);
return success();
}

return failure();
}
};

using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
Expand Down
Loading