Skip to content

Commit a6c92a8

Browse files
authored
[SYCL-MLIR] Define SYCL index space operations (#8089)
Define high-level operations to retrieve information from the SYCL index space. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent a96ec1e commit a6c92a8

File tree

6 files changed

+602
-0
lines changed

6 files changed

+602
-0
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOpInterfaces.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,7 @@ def SYCLGetID : SYCLOpTrait<"SYCLGetID">;
9595
def SYCLGetComponent : SYCLOpTrait<"SYCLGetComponent">;
9696
def SYCLGetRange : SYCLOpTrait<"SYCLGetRange">;
9797
def SYCLGetGroup : SYCLOpTrait<"SYCLGetGroup">;
98+
def SYCLIndexSpaceGetID : SYCLOpTrait<"SYCLIndexSpaceGetID">;
99+
def SYCLIndexSpaceGetRange : SYCLOpTrait<"SYCLIndexSpaceGetRange">;
98100

99101
#endif // SYCL_OP_INTERFACES

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOpTraits.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ LogicalResult verifySYCLGetIDTrait(Operation *Op);
1717
LogicalResult verifySYCLGetComponentTrait(Operation *Op);
1818
LogicalResult verifySYCLGetRangeTrait(Operation *Op);
1919
LogicalResult verifySYCLGetGroupTrait(Operation *Op);
20+
LogicalResult verifySYCLIndexSpaceGetIDTrait(Operation *Op);
21+
LogicalResult verifySYCLIndexSpaceGetRangeTrait(Operation *Op);
2022

2123
/// This interface describes an SYCLMethodOpInterface that returns a range if
2224
/// called with a single argument and a size_t if called with two arguments.
@@ -64,6 +66,28 @@ class SYCLGetGroup : public OpTrait::TraitBase<ConcreteType, SYCLGetGroup> {
6466
return verifySYCLGetGroupTrait(Op);
6567
}
6668
};
69+
70+
/// This interface describes an operation returning either a SYCL ID type (for
71+
/// cardinality 0) or an MLIR index type (for cardinality 1).
72+
template <typename ConcreteType>
73+
class SYCLIndexSpaceGetID
74+
: public OpTrait::TraitBase<ConcreteType, SYCLIndexSpaceGetID> {
75+
public:
76+
static LogicalResult verifyTrait(Operation *Op) {
77+
return verifySYCLIndexSpaceGetIDTrait(Op);
78+
}
79+
};
80+
81+
/// This interface describes an operation returning either a SYCL range type
82+
/// (for cardinality 0) or an MLIR index type (for cardinality 1).
83+
template <typename ConcreteType>
84+
class SYCLIndexSpaceGetRange
85+
: public OpTrait::TraitBase<ConcreteType, SYCLIndexSpaceGetRange> {
86+
public:
87+
static LogicalResult verifyTrait(Operation *Op) {
88+
return verifySYCLIndexSpaceGetRangeTrait(Op);
89+
}
90+
};
6791
} // namespace sycl
6892
} // namespace mlir
6993

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,188 @@ def SYCLMemref : AnyTypeOf<[
413413
def IndexType : AnyTypeOf<[I32, I64, Index]>;
414414
def SYCLGetIDResult : AnyTypeOf<[I64, SYCL_IDType]>;
415415
def SYCLGetRangeResult : AnyTypeOf<[I64, SYCL_RangeType]>;
416+
def SYCLIndexSpaceGetIDResult : AnyTypeOf<[Index, SYCL_IDType]>;
417+
def SYCLIndexSpaceGetRangeResult : AnyTypeOf<[Index, SYCL_RangeType]>;
416418
def SYCLGetResult : AnyTypeOf<[I64, MemRefOf<[I64]>]>;
417419
def VectorEltTy : AnyTypeOf<[I1, I8, I16, I32, I64, F16, F32, F64]>;
418420
def VectorSplatArg : MemRefOf<[VectorEltTy]>;
419421

422+
////////////////////////////////////////////////////////////////////////////////
423+
// SYCL GRID OPERATIONS
424+
////////////////////////////////////////////////////////////////////////////////
425+
426+
def SYCLNumWorkItemsOp : SYCL_Op<"num_work_items", [SYCLIndexSpaceGetRange]> {
427+
let summary = "Retrieve the number of work-items.";
428+
let description = [{
429+
This operation returns the number of work-items in the index space. If the
430+
optional argument is passed, the number of work-items in the given dimension
431+
is returned.
432+
}];
433+
434+
let arguments = (ins Optional<I32>:$dimension);
435+
let results = (outs SYCLIndexSpaceGetRangeResult:$result);
436+
let assemblyFormat = [{
437+
`(` operands `)` attr-dict `:` functional-type(operands, results)
438+
}];
439+
}
440+
441+
def SYCLGlobalIDOp : SYCL_Op<"global_id", [SYCLIndexSpaceGetID]> {
442+
let summary = "Retrieve the global ID of the item.";
443+
let description = [{
444+
This operation returns the global ID of the item in the index space. If the
445+
optional argument is passed, the global ID of the item in the given
446+
dimension is returned.
447+
}];
448+
449+
let arguments = (ins Optional<I32>:$dimension);
450+
let results = (outs SYCLIndexSpaceGetIDResult:$result);
451+
let assemblyFormat = [{
452+
`(` operands `)` attr-dict `:` functional-type(operands, results)
453+
}];
454+
}
455+
456+
def SYCLLocalIDOp : SYCL_Op<"local_id", [SYCLIndexSpaceGetID]> {
457+
let summary = "Retrieve the local ID of the item.";
458+
let description = [{
459+
This operation returns the local ID of the item within the work-group. If
460+
the optional argument is passed, the local ID of the item in the given
461+
dimension is returned.
462+
}];
463+
464+
let arguments = (ins Optional<I32>:$dimension);
465+
let results = (outs SYCLIndexSpaceGetIDResult:$result);
466+
let assemblyFormat = [{
467+
`(` operands `)` attr-dict `:` functional-type(operands, results)
468+
}];
469+
}
470+
471+
def SYCLGlobalOffsetOp : SYCL_Op<"global_offset", [SYCLIndexSpaceGetID]>,
472+
Deprecated<"deprecated in SYCL 2020"> {
473+
let summary = "Retrieve the global offset of the item.";
474+
let description = [{
475+
This operation returns the global offset of the item in the index space. If
476+
the optional argument is passed, the global offset of the item in the given
477+
dimension is returned.
478+
479+
Note that this is deprecated in SYCL 2020.
480+
}];
481+
482+
let arguments = (ins Optional<I32>:$dimension);
483+
let results = (outs SYCLIndexSpaceGetIDResult:$result);
484+
let assemblyFormat = [{
485+
`(` operands `)` attr-dict `:` functional-type(operands, results)
486+
}];
487+
}
488+
489+
def SYCLNumWorkGroupsOp : SYCL_Op<"num_work_groups", [SYCLIndexSpaceGetRange]> {
490+
let summary = "Retrieve the number of work-groups.";
491+
let description = [{
492+
This operation returns the number of work-groups in the index space. If the
493+
optional argument is passed, the number of work-groups in the given
494+
dimension is returned.
495+
}];
496+
497+
let arguments = (ins Optional<I32>:$dimension);
498+
let results = (outs SYCLIndexSpaceGetRangeResult:$result);
499+
let assemblyFormat = [{
500+
`(` operands `)` attr-dict `:` functional-type(operands, results)
501+
}];
502+
}
503+
504+
def SYCLWorkGroupSizeOp : SYCL_Op<"work_group_size", [SYCLIndexSpaceGetRange]> {
505+
let summary = "Retrieve the number of work-items in a work-group.";
506+
let description = [{
507+
This operation returns the number of work-items per work-group. If the
508+
optional argument is passed, the number of work-items per work-group in the
509+
given dimension is returned.
510+
}];
511+
512+
let arguments = (ins Optional<I32>:$dimension);
513+
let results = (outs SYCLIndexSpaceGetRangeResult:$result);
514+
let assemblyFormat = [{
515+
`(` operands `)` attr-dict `:` functional-type(operands, results)
516+
}];
517+
}
518+
519+
def SYCLWorkGroupIDOp : SYCL_Op<"work_group_id", [SYCLIndexSpaceGetID]> {
520+
let summary = "Retrieve the ID of the work-group.";
521+
let description = [{
522+
This operation returns the ID of the work-group. If the optional argument is
523+
passed, the ID of the work-group in the given dimension is returned.
524+
}];
525+
526+
let arguments = (ins Optional<I32>:$dimension);
527+
let results = (outs SYCLIndexSpaceGetIDResult:$result);
528+
let assemblyFormat = [{
529+
`(` operands `)` attr-dict `:` functional-type(operands, results)
530+
}];
531+
}
532+
533+
def SYCLNumSubGroupsOp : SYCL_Op<"num_sub_groups"> {
534+
let summary = "Retrieve the number of sub-groups.";
535+
let description = [{
536+
This operation returns the number of sub-groups.
537+
}];
538+
539+
let arguments = (ins);
540+
let results = (outs I32:$result);
541+
let assemblyFormat = [{
542+
attr-dict `:` functional-type(operands, results)
543+
}];
544+
}
545+
546+
def SYCLSubGroupMaxSizeOp : SYCL_Op<"sub_group_max_size"> {
547+
let summary = "Retrieve the maximum size of a sub-group.";
548+
let description = [{
549+
This operation returns the maximum size of a sub-group.
550+
}];
551+
552+
let arguments = (ins);
553+
let results = (outs I32:$result);
554+
let assemblyFormat = [{
555+
attr-dict `:` functional-type(operands, results)
556+
}];
557+
}
558+
559+
def SYCLSubGroupSizeOp : SYCL_Op<"sub_group_size"> {
560+
let summary = "Retrieve the sub-group size.";
561+
let description = [{
562+
This operation returns the sub-group size.
563+
}];
564+
565+
let arguments = (ins);
566+
let results = (outs I32:$result);
567+
let assemblyFormat = [{
568+
attr-dict `:` functional-type(operands, results)
569+
}];
570+
}
571+
572+
def SYCLSubGroupIDOp : SYCL_Op<"sub_group_id"> {
573+
let summary = "Retrieve the ID of the sub-group.";
574+
let description = [{
575+
This operation returns the ID of the sub-group.
576+
}];
577+
578+
let arguments = (ins);
579+
let results = (outs I32:$result);
580+
let assemblyFormat = [{
581+
attr-dict `:` functional-type(operands, results)
582+
}];
583+
}
584+
585+
def SYCLSubGroupLocalIDOp : SYCL_Op<"sub_group_local_id"> {
586+
let summary = "Retrieve the local ID of the sub-group.";
587+
let description = [{
588+
This operation returns the local ID of the sub-group.
589+
}];
590+
591+
let arguments = (ins);
592+
let results = (outs I32:$result);
593+
let assemblyFormat = [{
594+
attr-dict `:` functional-type(operands, results)
595+
}];
596+
}
597+
420598
////////////////////////////////////////////////////////////////////////////////
421599
// CONSTRUCTOR OPERATION
422600
////////////////////////////////////////////////////////////////////////////////

mlir-sycl/lib/Dialect/IR/SYCLTraits.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010

1111
#include "mlir/Dialect/SYCL/IR/SYCLOpTraits.h"
1212

13+
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/SYCL/IR/SYCLOps.h"
1415

1516
#include "llvm/ADT/TypeSwitch.h"
1617

18+
using namespace mlir;
19+
using namespace mlir::sycl;
20+
1721
static unsigned getDimensions(mlir::Type Type) {
1822
if (auto MemRefTy = Type.dyn_cast<mlir::MemRefType>()) {
1923
Type = MemRefTy.getElementType();
@@ -174,3 +178,37 @@ mlir::LogicalResult mlir::sycl::verifySYCLGetGroupTrait(Operation *Op) {
174178
return verifyGetSYCLTyOperation(cast<mlir::sycl::SYCLMethodOpInterface>(Op),
175179
"group");
176180
}
181+
182+
static LogicalResult verifyIndexSpaceTrait(Operation *Op) {
183+
const auto Ty = Op->getResultTypes();
184+
assert(Ty.size() == 1 && "Expecting a single return value");
185+
const auto IsIndex = Ty[0].isa<IndexType>();
186+
switch (Op->getNumOperands()) {
187+
case 0:
188+
return !IsIndex ? success()
189+
: Op->emitOpError("Not expecting an index return value for "
190+
"this cardinality");
191+
case 1:
192+
if (auto C = Op->getOperand(0).getDefiningOp<arith::ConstantOp>()) {
193+
const auto Value = static_cast<arith::ConstantIntOp>(C).value();
194+
if (!(0 <= Value && Value < 3)) {
195+
return Op->emitOpError(
196+
"The SYCL index space can only be 1, 2, or 3 dimensional");
197+
}
198+
}
199+
return IsIndex
200+
? success()
201+
: Op->emitOpError(
202+
"Expecting an index return value for this cardinality");
203+
default:
204+
llvm_unreachable("Invalid cardinality");
205+
}
206+
}
207+
208+
LogicalResult mlir::sycl::verifySYCLIndexSpaceGetIDTrait(Operation *Op) {
209+
return verifyIndexSpaceTrait(Op);
210+
}
211+
212+
LogicalResult mlir::sycl::verifySYCLIndexSpaceGetRangeTrait(Operation *Op) {
213+
return verifyIndexSpaceTrait(Op);
214+
}

0 commit comments

Comments
 (0)