Skip to content

[mlir][NVVM] Add ops for vote all and any sync #134309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6542,7 +6542,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
mlir::Value arg1 =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
return builder
.create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
.create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
mlir::NVVM::VoteSyncKind::ballot)
.getResult();
}

Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ end subroutine
! CHECK-LABEL: func.func @_QPtestvote()
! CHECK: fir.call @llvm.nvvm.vote.all.sync
! CHECK: fir.call @llvm.nvvm.vote.any.sync
! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32

! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
Expand Down
48 changes: 41 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -808,15 +808,49 @@ def NVVM_ShflOp :
let hasVerifier = 1;
}

def NVVM_VoteBallotOp :
NVVM_Op<"vote.ballot.sync">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;

def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
[VoteSyncKindAny, VoteSyncKindAll,
VoteSyncKindBallot, VoteSyncKindUni]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}

def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;

def NVVM_VoteSyncOp
: NVVM_Op<"vote.sync">,
Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
let summary = "Vote across thread group";
let description = [{
The `vote.sync` op causes the executing thread to wait until all non-exited
threads corresponding to membermask have executed `vote.sync` with the same
qualifiers and the same membermask value before resuming execution.

The vote operation kinds are:
- `any`: True if source predicate is True for some thread in membermask.
- `all`: True if source predicate is True for all non-exited threads in
membermask.
- `uni`: True if source predicate has the same value in all non-exited
threads in membermask.
- `ballot`: In the ballot form, the destination result is a 32 bit integer.
In this form, the predicate from each thread in membermask is copied into
the corresponding bit position of the result, where the bit position
corresponds to the thread's lane id.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
}];
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
auto intId = getVoteSyncIntrinsicId($kind);
$res = createIntrinsicCall(builder, intId, {$mask, $pred});
}];
let hasCustomAssemblyFormat = 1;
let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
let hasVerifier = 1;
}

def NVVM_SyncWarpOp :
Expand Down
41 changes: 13 additions & 28 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,34 +48,6 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

/// Prints the trailing syntax of `op`: a space followed by its operand list
/// and, only when the op has at least one result, ` : ` followed by its
/// result types.
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  // Ops without results have no type suffix to print.
  if (op->getNumResults() == 0)
    return;
  p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  // The two operands are resolved against fixed types: i32 first, then i1.
  MLIRContext *ctx = parser.getContext();
  auto maskTy = IntegerType::get(ctx, 32);
  auto predTy = IntegerType::get(ctx, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
  Type resultTy;
  // Parse operands, the optional attribute dictionary and the `: type` suffix,
  // then record the result type and resolve the operands. The `||` chain stops
  // at the first step that fails.
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultTy) ||
      parser.addTypeToList(resultTy, result.types) ||
      parser.resolveOperands(operands, {maskTy, predTy}, parser.getNameLoc(),
                             result.operands))
    return failure();
  return success();
}

/// Prints this op by delegating to the shared NVVM intrinsic printer.
void VoteBallotOp::print(OpAsmPrinter &p) {
  printNVVMIntrinsicOp(p, *this);
}

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1160,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
return success();
}

/// Verifies that the result type agrees with the vote kind: `ballot` must
/// produce an i32, every other kind must produce an i1.
LogicalResult NVVM::VoteSyncOp::verify() {
  const bool isBallot = getKind() == NVVM::VoteSyncKind::ballot;
  if (isBallot && !getType().isInteger(32))
    return emitOpError("vote.sync 'ballot' returns an i32");
  if (!isBallot && !getType().isInteger(1))
    return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
  return success();
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
}
}

/// Return the LLVM intrinsic ID that implements `nvvm.vote.sync` for the given
/// vote kind.
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
  // Deliberately no `default` label: the switch lists every VoteSyncKind
  // enumerator, so the compiler can warn here if a new kind is added without
  // a matching case.
  switch (kind) {
  case NVVM::VoteSyncKind::any:
    return llvm::Intrinsic::nvvm_vote_any_sync;
  case NVVM::VoteSyncKind::all:
    return llvm::Intrinsic::nvvm_vote_all_sync;
  case NVVM::VoteSyncKind::ballot:
    return llvm::Intrinsic::nvvm_vote_ballot_sync;
  case NVVM::VoteSyncKind::uni:
    return llvm::Intrinsic::nvvm_vote_uni_sync;
  }
  // Only reachable if `kind` holds a value outside the enumeration.
  llvm_unreachable("unsupported vote kind");
}

/// Return the intrinsic ID associated with ldmatrix for the given parameters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
Expand Down
10 changes: 8 additions & 2 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ func.func @nvvm_shfl_pred(

// CHECK-LABEL: @nvvm_vote(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
// CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
%0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
// CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
%1 = nvvm.vote.sync all %arg0, %arg1 -> i1
// CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
%2 = nvvm.vote.sync any %arg0, %arg1 -> i1
// CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
%3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
llvm.return %0 : i32
}

Expand Down
8 changes: 7 additions & 1 deletion mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
// CHECK-LABEL: @nvvm_vote
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
%3 = nvvm.vote.ballot.sync %0, %1 : i32
%3 = nvvm.vote.sync ballot %0, %1 -> i32
// CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
%4 = nvvm.vote.sync all %0, %1 -> i1
// CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
%5 = nvvm.vote.sync any %0, %1 -> i1
// CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
%6 = nvvm.vote.sync uni %0, %1 -> i1
llvm.return %3 : i32
}

Expand Down