@@ -3640,36 +3640,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
3640
3640
}
3641
3641
3642
3642
//===----------------------------------------------------------------------===//
3643
- // NVVM dot.accumulate.4way Op
3643
+ // NVVM dot.accumulate Ops
3644
3644
//===----------------------------------------------------------------------===//
3645
3645
3646
- def DotAccumulate4WayS8 : I32EnumAttrCase<"S8 ", 1 , "s8 ">;
3647
- def DotAccumulate4WayU8 : I32EnumAttrCase<"U8 ", 0 , "u8 ">;
3646
+ def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED ", 0 , "unsigned ">;
3647
+ def DotAccumulateSigned : I32EnumAttrCase<"SIGNED ", 1 , "signed ">;
3648
3648
3649
- def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType ",
3650
- "NVVM DotAccumulate4WayType ",
3651
- [DotAccumulate4WayS8, DotAccumulate4WayU8 ]> {
3649
+ def DotAccumulateType : I32EnumAttr<"DotAccumulateType ",
3650
+ "NVVM DotAccumulateType ",
3651
+ [DotAccumulateSigned, DotAccumulateUnsigned ]> {
3652
3652
let cppNamespace = "::mlir::NVVM";
3653
3653
let genSpecializedAttr = 0;
3654
3654
}
3655
3655
3656
- def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType , "dot_accumulate_4way_type "> {
3656
+ def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType , "dot_accumulate_type "> {
3657
3657
let assemblyFormat = "`<` $value `>`";
3658
3658
}
3659
3659
3660
3660
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3661
- let summary = "Four-way byte dot product-accumulate instruction. ";
3661
+ let summary = "Four-way byte dot product-accumulate instruction";
3662
3662
let description = [{
3663
3663
Performs a four-way byte dot-product which is accumulated in a 32-bit
3664
3664
result.
3665
3665
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
3666
3666
computed.
3667
+
3667
3668
The `a_type` and `b_type` attributes specify the type of the elements in `a`
3668
3669
and `b` respectively.
3669
- If `a_type` or `b_type` is `s8 `, then the elements in the corresponding
3670
+ If `a_type` or `b_type` is `signed `, then the elements in the corresponding
3670
3671
vector are sign-extended to 32-bit before the dot product is computed.
3671
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3672
- vector are zero-extended to 32-bit instead.
3672
+ If `a_type` or `b_type` is `unsigned`, then the elements in the
3673
+ corresponding vector are zero-extended to 32-bit instead.
3674
+
3673
3675
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3674
3676
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
3675
3677
@@ -3678,9 +3680,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3678
3680
3679
3681
let arguments = (ins
3680
3682
VectorOfLengthAndType<[4], [I8]>:$a,
3681
- DotAccumulate4WayTypeAttr :$a_type,
3683
+ DotAccumulateTypeAttr :$a_type,
3682
3684
VectorOfLengthAndType<[4], [I8]>:$b,
3683
- DotAccumulate4WayTypeAttr :$b_type,
3685
+ DotAccumulateTypeAttr :$b_type,
3684
3686
I32:$c
3685
3687
);
3686
3688
@@ -3689,17 +3691,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3689
3691
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
3690
3692
3691
3693
let extraClassDeclaration = [{
3692
- static llvm::Intrinsic::ID
3693
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3694
- NVVM::DotAccumulate4WayType b_type);
3695
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3694
+ static mlir::NVVM::IDArgPair
3695
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3696
+ llvm::IRBuilderBase &builder);
3696
3697
}];
3697
3698
3698
3699
string llvmBuilder = [{
3699
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3700
- llvm::Value* argA = op.getPackedArg($a, builder);
3701
- llvm::Value* argB = op.getPackedArg($b, builder);
3702
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
3700
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3701
+ *op, moduleTranslation, builder);
3702
+ $res = createIntrinsicCall(builder, id, args);
3703
3703
}];
3704
3704
}
3705
3705
0 commit comments