@@ -419,78 +419,26 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
419
419
return rewriter.notifyMatchFailure (
420
420
loc, llvm::formatv (" unsupported type: {0}" , op.getType ()));
421
421
422
- Type newElemTy = reduceInnermostDim (newTy);
423
- unsigned newBitWidth = newTy.getElementTypeBitWidth ();
424
- unsigned digitBitWidth = newBitWidth / 2 ;
425
-
426
422
auto [lhsElem0, lhsElem1] =
427
423
extractLastDimHalves (rewriter, loc, adaptor.getLhs ());
428
424
auto [rhsElem0, rhsElem1] =
429
425
extractLastDimHalves (rewriter, loc, adaptor.getRhs ());
430
426
431
- // Emulate multiplication by splitting each input element of type i2N into 4
432
- // digits of type iN and bit width i(N/2). This is so that the intermediate
433
- // multiplications and additions do not overflow. We extract these i(N/2)
434
- // digits from iN vector elements by masking (low digit) and shifting right
435
- // (high digit).
436
- //
437
427
// The multiplication algorithm used is the standard (long) multiplication.
438
- // Multiplying two i2N integers produces (at most) a i4N result, but because
439
- // the calculation of top i2N is not necessary, we omit it.
440
- // In total, this implementations performs 10 intermediate multiplications
441
- // and 16 additions. The number of multiplications could be decreased by
442
- // switching to a more efficient algorithm like Karatsuba. This would,
443
- // however, require being able to perform (intermediate) wide additions and
444
- // subtractions, so it is not clear that such implementation would be more
445
- // efficient.
446
-
447
- APInt lowMaskVal (newBitWidth, 1 );
448
- lowMaskVal = lowMaskVal.shl (digitBitWidth) - 1 ;
449
- Value lowMask =
450
- createScalarOrSplatConstant (rewriter, loc, newElemTy, lowMaskVal);
451
- auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
452
- return rewriter.create <arith::AndIOp>(loc, newElemTy, v, lowMask);
453
- };
428
+ // Multiplying two i2N integers produces (at most) an i4N result, but
429
+ // because the calculation of top i2N is not necessary, we omit it.
430
+ auto mulLowLow =
431
+ rewriter.create <arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
432
+ Value mulLowHi = rewriter.create <arith::MulIOp>(loc, lhsElem0, rhsElem1);
433
+ Value mulHiLow = rewriter.create <arith::MulIOp>(loc, lhsElem1, rhsElem0);
434
+
435
+ Value resLow = mulLowLow.getLow ();
436
+ Value resHi =
437
+ rewriter.create <arith::AddIOp>(loc, mulLowLow.getHigh (), mulLowHi);
438
+ resHi = rewriter.create <arith::AddIOp>(loc, resHi, mulHiLow);
454
439
455
- Value shiftVal =
456
- createScalarOrSplatConstant (rewriter, loc, newElemTy, digitBitWidth);
457
- auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
458
- return rewriter.create <arith::ShRUIOp>(loc, v, shiftVal);
459
- };
460
-
461
- Value zeroDigit = createScalarOrSplatConstant (rewriter, loc, newElemTy, 0 );
462
- std::array<Value, 4 > resultDigits = {zeroDigit, zeroDigit, zeroDigit,
463
- zeroDigit};
464
- std::array<Value, 4 > lhsDigits = {
465
- getLowDigit (lhsElem0), getHighDigit (lhsElem0), getLowDigit (lhsElem1),
466
- getHighDigit (lhsElem1)};
467
- std::array<Value, 4 > rhsDigits = {
468
- getLowDigit (rhsElem0), getHighDigit (rhsElem0), getLowDigit (rhsElem1),
469
- getHighDigit (rhsElem1)};
470
-
471
- for (unsigned i = 0 , e = lhsDigits.size (); i != e; ++i) {
472
- for (unsigned j = 0 ; i + j != e; ++j) {
473
- Value mul =
474
- rewriter.create <arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
475
- Value current =
476
- rewriter.createOrFold <arith::AddIOp>(loc, resultDigits[i + j], mul);
477
- resultDigits[i + j] = getLowDigit (current);
478
- if (i + j + 1 != e) {
479
- Value carry = rewriter.createOrFold <arith::AddIOp>(
480
- loc, resultDigits[i + j + 1 ], getHighDigit (current));
481
- resultDigits[i + j + 1 ] = carry;
482
- }
483
- }
484
- }
485
-
486
- auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
487
- Value highBits = rewriter.create <arith::ShLIOp>(loc, high, shiftVal);
488
- return rewriter.create <arith::OrIOp>(loc, low, highBits);
489
- };
490
- Value resultElem0 = combineDigits (resultDigits[0 ], resultDigits[1 ]);
491
- Value resultElem1 = combineDigits (resultDigits[2 ], resultDigits[3 ]);
492
440
Value resultVec =
493
- constructResultVector (rewriter, loc, newTy, {resultElem0, resultElem1 });
441
+ constructResultVector (rewriter, loc, newTy, {resLow, resHi });
494
442
rewriter.replaceOp (op, resultVec);
495
443
return success ();
496
444
}
0 commit comments