@@ -3533,36 +3533,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
3533
3533
}
3534
3534
3535
3535
//===----------------------------------------------------------------------===//
3536
- // NVVM dot.accumulate.4way Op
3536
+ // NVVM dot.accumulate Ops
3537
3537
//===----------------------------------------------------------------------===//
3538
3538
3539
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1 , "s8 ">;
3540
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0 , "u8 ">;
3539
+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0 , "unsigned ">;
3540
+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1 , "signed ">;
3541
3541
3542
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3543
- "NVVM DotAccumulate4WayType ",
3544
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3542
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3543
+ "NVVM DotAccumulateType ",
3544
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3545
3545
let cppNamespace = "::mlir::NVVM";
3546
3546
let genSpecializedAttr = 0;
3547
3547
}
3548
3548
3549
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3549
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3550
3550
let assemblyFormat = "`<` $value `>`";
3551
3551
}
3552
3552
3553
3553
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3554
- let summary = "Four-way byte dot product-accumulate instruction. ";
3554
+ let summary = "Four-way byte dot product-accumulate instruction";
3555
3555
let description = [{
3556
3556
Performs a four-way byte dot-product which is accumulated in a 32-bit
3557
3557
result.
3558
3558
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3559
3559
computed.
3560
+
3560
3561
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3561
3562
and `b` respectively.
3562
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3563
+ If `a_type` or `b_type` is `signed `, then the elements in the corresponding
3563
3564
vector are sign-extended to 32-bit before the dot product is computed.
3564
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3565
- vector are zero-extended to 32-bit instead.
3565
+ If `a_type` or `b_type` is `unsigned`, then the elements in the
3566
+ corresponding vector are zero-extended to 32-bit instead.
3567
+
3566
3568
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3567
3569
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
3568
3570
@@ -3571,9 +3573,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3571
3573
3572
3574
let arguments = (ins
3573
3575
VectorOfLengthAndType<[4], [I8]>:$a,
3574
- DotAccumulate4WayTypeAttr :$a_type,
3576
+ DotAccumulateTypeAttr :$a_type,
3575
3577
VectorOfLengthAndType<[4], [I8]>:$b,
3576
- DotAccumulate4WayTypeAttr :$b_type,
3578
+ DotAccumulateTypeAttr :$b_type,
3577
3579
I32:$c
3578
3580
);
3579
3581
@@ -3582,17 +3584,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3582
3584
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3583
3585
3584
3586
let extraClassDeclaration = [{
3585
- static llvm::Intrinsic::ID
3586
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3587
- NVVM::DotAccumulate4WayType b_type);
3588
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3587
+ static mlir::NVVM::IDArgPair
3588
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3589
+ llvm::IRBuilderBase &builder);
3589
3590
}];
3590
3591
3591
3592
string llvmBuilder = [{
3592
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3593
- llvm::Value* argA = op.getPackedArg($a, builder);
3594
- llvm::Value* argB = op.getPackedArg($b, builder);
3595
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3593
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3594
+ *op, moduleTranslation, builder);
3595
+ $res = createIntrinsicCall(builder, id, args);
3596
3596
}];
3597
3597
}
3598
3598
0 commit comments