Skip to content

Commit 3f36d2d

Browse files
committed
[mlir][arith] Simplify muli emulation with mului_extended
Using `arith.mului_extended` makes it much simpler to emulate wide integer multiplication.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139776
1 parent 4446f71 commit 3f36d2d

File tree

3 files changed

+35
-102
lines changed

3 files changed

+35
-102
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -419,78 +419,26 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
419419
return rewriter.notifyMatchFailure(
420420
loc, llvm::formatv("unsupported type: {0}", op.getType()));
421421

422-
Type newElemTy = reduceInnermostDim(newTy);
423-
unsigned newBitWidth = newTy.getElementTypeBitWidth();
424-
unsigned digitBitWidth = newBitWidth / 2;
425-
426422
auto [lhsElem0, lhsElem1] =
427423
extractLastDimHalves(rewriter, loc, adaptor.getLhs());
428424
auto [rhsElem0, rhsElem1] =
429425
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
430426

431-
// Emulate multiplication by splitting each input element of type i2N into 4
432-
// digits of type iN and bit width i(N/2). This is so that the intermediate
433-
// multiplications and additions do not overflow. We extract these i(N/2)
434-
// digits from iN vector elements by masking (low digit) and shifting right
435-
// (high digit).
436-
//
437427
// The multiplication algorithm used is the standard (long) multiplication.
438-
// Multiplying two i2N integers produces (at most) a i4N result, but because
439-
// the calculation of top i2N is not necessary, we omit it.
440-
// In total, this implementations performs 10 intermediate multiplications
441-
// and 16 additions. The number of multiplications could be decreased by
442-
// switching to a more efficient algorithm like Karatsuba. This would,
443-
// however, require being able to perform (intermediate) wide additions and
444-
// subtractions, so it is not clear that such implementation would be more
445-
// efficient.
446-
447-
APInt lowMaskVal(newBitWidth, 1);
448-
lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1;
449-
Value lowMask =
450-
createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal);
451-
auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
452-
return rewriter.create<arith::AndIOp>(loc, newElemTy, v, lowMask);
453-
};
428+
// Multiplying two i2N integers produces (at most) an i4N result, but
429+
// because the calculation of top i2N is not necessary, we omit it.
430+
auto mulLowLow =
431+
rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
432+
Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
433+
Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
434+
435+
Value resLow = mulLowLow.getLow();
436+
Value resHi =
437+
rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
438+
resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
454439

455-
Value shiftVal =
456-
createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth);
457-
auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
458-
return rewriter.create<arith::ShRUIOp>(loc, v, shiftVal);
459-
};
460-
461-
Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0);
462-
std::array<Value, 4> resultDigits = {zeroDigit, zeroDigit, zeroDigit,
463-
zeroDigit};
464-
std::array<Value, 4> lhsDigits = {
465-
getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1),
466-
getHighDigit(lhsElem1)};
467-
std::array<Value, 4> rhsDigits = {
468-
getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1),
469-
getHighDigit(rhsElem1)};
470-
471-
for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) {
472-
for (unsigned j = 0; i + j != e; ++j) {
473-
Value mul =
474-
rewriter.create<arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
475-
Value current =
476-
rewriter.createOrFold<arith::AddIOp>(loc, resultDigits[i + j], mul);
477-
resultDigits[i + j] = getLowDigit(current);
478-
if (i + j + 1 != e) {
479-
Value carry = rewriter.createOrFold<arith::AddIOp>(
480-
loc, resultDigits[i + j + 1], getHighDigit(current));
481-
resultDigits[i + j + 1] = carry;
482-
}
483-
}
484-
}
485-
486-
auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
487-
Value highBits = rewriter.create<arith::ShLIOp>(loc, high, shiftVal);
488-
return rewriter.create<arith::OrIOp>(loc, low, highBits);
489-
};
490-
Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]);
491-
Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]);
492440
Value resultVec =
493-
constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1});
441+
constructResultVector(rewriter, loc, newTy, {resLow, resHi});
494442
rewriter.replaceOp(op, resultVec);
495443
return success();
496444
}

mlir/test/Dialect/Arith/emulate-wide-int-very-wide.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi512>
1010
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi512>
1111
//
12-
// Check that the mask for the low 256-bits was generated correctly. The exact expected value is 2^256 - 1.
13-
// CHECK-NEXT: {{.+}} = arith.constant 115792089237316195423570985008687907853269984665640564039457584007913129639935 : i512
12+
// CHECK-DAG: arith.mului_extended
13+
// CHECK-DAG: arith.muli
14+
// CHECK-DAG: arith.muli
15+
// CHECK-NEXT: arith.addi
16+
// CHECK-NEXT: arith.addi
17+
//
1418
// CHECK: return {{%.+}} : vector<2xi512>
1519
func.func @muli_scalar(%a : i1024, %b : i1024) -> i1024 {
1620
%m = arith.muli %a, %b : i1024

mlir/test/Dialect/Arith/emulate-wide-int.mlir

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -661,44 +661,20 @@ func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c
661661

662662
// CHECK-LABEL: func.func @muli_scalar
663663
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
664-
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
665-
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
666-
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
667-
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
664+
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
665+
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
666+
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
667+
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
668668
//
669-
// CHECK-DAG: [[MASK:%.+]] = arith.constant 65535 : i32
670-
// CHECK-DAG: [[C16:%.+]] = arith.constant 16 : i32
669+
// CHECK-DAG: [[RESLOW:%.+]], [[HI0:%.+]] = arith.mului_extended [[LOW0]], [[LOW1]] : i32
670+
// CHECK-DAG: [[HI1:%.+]] = arith.muli [[LOW0]], [[HIGH1]] : i32
671+
// CHECK-DAG: [[HI2:%.+]] = arith.muli [[HIGH0]], [[LOW1]] : i32
672+
// CHECK-NEXT: [[RESHI1:%.+]] = arith.addi [[HI0]], [[HI1]] : i32
673+
// CHECK-NEXT: [[RESHI2:%.+]] = arith.addi [[RESHI1]], [[HI2]] : i32
671674
//
672-
// CHECK: [[LOWLOW0:%.+]] = arith.andi [[LOW0]], [[MASK]] : i32
673-
// CHECK-NEXT: [[HIGHLOW0:%.+]] = arith.shrui [[LOW0]], [[C16]] : i32
674-
// CHECK-NEXT: [[LOWHIGH0:%.+]] = arith.andi [[HIGH0]], [[MASK]] : i32
675-
// CHECK-NEXT: [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32
676-
// CHECK-NEXT: [[LOWLOW1:%.+]] = arith.andi [[LOW1]], [[MASK]] : i32
677-
// CHECK-NEXT: [[HIGHLOW1:%.+]] = arith.shrui [[LOW1]], [[C16]] : i32
678-
// CHECK-NEXT: [[LOWHIGH1:%.+]] = arith.andi [[HIGH1]], [[MASK]] : i32
679-
// CHECK-NEXT: [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32
680-
//
681-
// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32
682-
// CHECK-DAG {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32
683-
// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32
684-
// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32
685-
//
686-
// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32
687-
// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32
688-
// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32
689-
//
690-
// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32
691-
// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32
692-
//
693-
// CHECK-DAG: {{%.+}} = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32
694-
//
695-
// CHECK: [[RESHIGH0:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
696-
// CHECK-NEXT: [[RES0:%.+]] = arith.ori {{%.+}}, [[RESHIGH0]] : i32
697-
// CHECK-NEXT: [[RESHIGH1:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
698-
// CHECK-NEXT: [[RES1:%.+]] = arith.ori {{%.+}}, [[RESHIGH1]] : i32
699-
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
700-
// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
701-
// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
675+
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
676+
// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RESLOW]], [[VZ]] [0] : i32 into vector<2xi32>
677+
// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RESHI2]], [[INS0]] [1] : i32 into vector<2xi32>
702678
// CHECK-NEXT: return [[INS1]] : vector<2xi32>
703679
func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
704680
%m = arith.muli %a, %b : i64
@@ -707,6 +683,11 @@ func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
707683

708684
// CHECK-LABEL: func.func @muli_vector
709685
// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
686+
// CHECK-DAG: arith.mului_extended
687+
// CHECK-DAG: arith.muli
688+
// CHECK-DAG: arith.muli
689+
// CHECK-NEXT: arith.addi
690+
// CHECK-NEXT: arith.addi
710691
// CHECK: return {{%.+}} : vector<3x2xi32>
711692
func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
712693
%m = arith.muli %a, %b : vector<3xi64>

0 commit comments

Comments
 (0)