Skip to content

Commit 380fcb8

Browse files
committed
[MLIR][NVVM] Update dot.accumulate.4way NVVM Op
This change refactors and updates the dot.accumulate.4way NVVM Op to be more descriptive and readable.
1 parent f0ab64b commit 380fcb8

File tree

4 files changed

+53
-53
lines changed

4 files changed

+53
-53
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3533,36 +3533,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
35333533
}
35343534

35353535
//===----------------------------------------------------------------------===//
3536-
// NVVM dot.accumulate.4way Op
3536+
// NVVM dot.accumulate Ops
35373537
//===----------------------------------------------------------------------===//
35383538

3539-
def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
3540-
def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
3539+
def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED", 0, "unsigned">;
3540+
def DotAccumulateSigned : I32EnumAttrCase<"SIGNED", 1, "signed">;
35413541

3542-
def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
3543-
"NVVM DotAccumulate4WayType",
3544-
[DotAccumulate4WayS8, DotAccumulate4WayU8]> {
3542+
def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
3543+
"NVVM DotAccumulateType",
3544+
[DotAccumulateSigned, DotAccumulateUnsigned]> {
35453545
let cppNamespace = "::mlir::NVVM";
35463546
let genSpecializedAttr = 0;
35473547
}
35483548

3549-
def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
3549+
def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
35503550
let assemblyFormat = "`<` $value `>`";
35513551
}
35523552

35533553
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3554-
let summary = "Four-way byte dot product-accumulate instruction.";
3554+
let summary = "Four-way byte dot product-accumulate instruction";
35553555
let description = [{
35563556
Performs a four-way byte dot-product which is accumulated in a 32-bit
35573557
result.
35583558
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
35593559
computed.
3560+
35603561
The `a_type` and `b_type` attributes specify the type of the elements in `a`
35613562
and `b` respectively.
3562-
If `a_type` or `b_type` is `s8`, then the elements in the corresponding
3563+
If `a_type` or `b_type` is `signed`, then the elements in the corresponding
35633564
vector are sign-extended to 32-bit before the dot product is computed.
3564-
If `a_type` or `b_type` is `u8`, then the elements in the corresponding
3565-
vector are zero-extended to 32-bit instead.
3565+
If `a_type` or `b_type` is `unsigned`, then the elements in the
3566+
corresponding vector are zero-extended to 32-bit instead.
3567+
35663568
Operand `c` is a 32-bit integer to which the result is accumulated. It is
35673569
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
35683570

@@ -3571,9 +3573,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35713573

35723574
let arguments = (ins
35733575
VectorOfLengthAndType<[4], [I8]>:$a,
3574-
DotAccumulate4WayTypeAttr:$a_type,
3576+
DotAccumulateTypeAttr:$a_type,
35753577
VectorOfLengthAndType<[4], [I8]>:$b,
3576-
DotAccumulate4WayTypeAttr:$b_type,
3578+
DotAccumulateTypeAttr:$b_type,
35773579
I32:$c
35783580
);
35793581

@@ -3582,17 +3584,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35823584
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
35833585

35843586
let extraClassDeclaration = [{
3585-
static llvm::Intrinsic::ID
3586-
getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
3587-
NVVM::DotAccumulate4WayType b_type);
3588-
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
3587+
static mlir::NVVM::IDArgPair
3588+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3589+
llvm::IRBuilderBase &builder);
35893590
}];
35903591

35913592
string llvmBuilder = [{
3592-
llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
3593-
llvm::Value* argA = op.getPackedArg($a, builder);
3594-
llvm::Value* argB = op.getPackedArg($b, builder);
3595-
$res = createIntrinsicCall(builder, id, {argA, argB, $c});
3593+
auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3594+
*op, moduleTranslation, builder);
3595+
$res = createIntrinsicCall(builder, id, args);
35963596
}];
35973597
}
35983598

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,13 +1205,6 @@ LogicalResult NVVM::VoteSyncOp::verify() {
12051205
return success();
12061206
}
12071207

1208-
llvm::Value *
1209-
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
1210-
llvm::IRBuilderBase &builder) {
1211-
return builder.CreateBitCast(arg,
1212-
llvm::Type::getInt32Ty(builder.getContext()));
1213-
}
1214-
12151208
//===----------------------------------------------------------------------===//
12161209
// getIntrinsicID/getIntrinsicIDAndArgs methods
12171210
//===----------------------------------------------------------------------===//
@@ -1627,24 +1620,31 @@ static void nvvmInferResultRanges(Operation *op, Value result,
16271620
}
16281621
}
16291622

1630-
llvm::Intrinsic::ID
1631-
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
1632-
NVVM::DotAccumulate4WayType b_type) {
1633-
bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
1634-
bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
1635-
unsigned type = (is_a_siext << 1) | is_b_siext;
1636-
switch (type) {
1637-
case 0:
1638-
return llvm::Intrinsic::nvvm_idp4a_u_u;
1639-
case 1:
1640-
return llvm::Intrinsic::nvvm_idp4a_u_s;
1641-
case 2:
1642-
return llvm::Intrinsic::nvvm_idp4a_s_u;
1643-
case 3:
1644-
return llvm::Intrinsic::nvvm_idp4a_s_s;
1645-
default:
1646-
llvm_unreachable("Invalid DP4a type");
1647-
}
1623+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
1624+
llvm::IRBuilderBase &builder) {
1625+
return builder.CreateBitCast(arg,
1626+
llvm::Type::getInt32Ty(builder.getContext()));
1627+
}
1628+
1629+
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
1630+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1631+
auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
1632+
1633+
llvm::SmallVector<llvm::Value *> args;
1634+
args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1635+
args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1636+
args.push_back(mt.lookupValue(curOp.getC()));
1637+
1638+
bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1639+
bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1640+
unsigned type = (isASigned << 1) | isBSigned;
1641+
const llvm::Intrinsic::ID ids[] = {
1642+
llvm::Intrinsic::nvvm_idp4a_u_u,
1643+
llvm::Intrinsic::nvvm_idp4a_u_s,
1644+
llvm::Intrinsic::nvvm_idp4a_s_u,
1645+
llvm::Intrinsic::nvvm_idp4a_s_s,
1646+
};
1647+
return {ids[type], args};
16481648
}
16491649

16501650
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,11 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
579579
}
580580

581581
// CHECK-LABEL: @dot_accumulate_4way
582-
func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
582+
func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
583583
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
584-
%1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
584+
%1 = nvvm.dot.accumulate.4way %a_vec <unsigned>, %b_vec <unsigned>, %c: vector<4xi8>, vector<4xi8>
585585
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
586-
%3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
586+
%3 = nvvm.dot.accumulate.4way %a_vec <signed>, %b_vec <signed>, %c: vector<4xi8>, vector<4xi8>
587587
return
588588
}
589589

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -851,18 +851,18 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
851851
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
852852
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
853853
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
854-
%0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
854+
%0 = nvvm.dot.accumulate.4way %a <unsigned>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
855855
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
856856
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
857857
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
858-
%1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
858+
%1 = nvvm.dot.accumulate.4way %a <signed>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
859859
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
860860
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
861861
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
862-
%2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
862+
%2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
863863
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
864864
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
865865
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
866-
%3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
866+
%3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
867867
llvm.return
868868
}

0 commit comments

Comments
 (0)