[MLIR][XeGPU] Add sg_map attribute to support Work Item level semantics #108864
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Petr Kurapov (kurapov-peter)

Changes

The PR adds an attribute (sg_map) describing the distribution of computation among work items for XeGPU operations, to be used in lowering passes. The map is attached to the tensor descriptor, so the constructor and the type are updated. Tests check the custom parser & printer. The attribute is currently optional, so no other changes are required.

Full diff: https://github.com/llvm/llvm-project/pull/108864.diff

4 Files Affected:
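As an aside (not part of the patch), the attribute's data model and verifier can be sketched in plain Python. The `SGMap` class and its rank-2 checks loosely mirror the `SGMapAttr` declaration and `SGMapAttr::verify` in the diff; the Python names are illustrative only:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class SGMap:
    """Sketch of the sg_map attribute: wi_layout describes how the work
    items of a subgroup are arranged over the 2D tensor; wi_data is the
    minimum number of elements each work item owns per distribution."""
    wi_layout: Tuple[int, ...]
    wi_data: Tuple[int, ...]

    def verify(self) -> None:
        # Mirrors SGMapAttr::verify below: both fields must have rank 2.
        if len(self.wi_layout) != 2:
            raise ValueError("expected wi_layout of size 2")
        if len(self.wi_data) != 2:
            raise ValueError("expected wi_data of size 2")

# The layout used throughout the tests in this PR:
SGMap(wi_layout=(1, 16), wi_data=(1, 1)).verify()  # passes: both rank 2
```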
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f3ca09a6a68ea8..576c79f66aaed9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -116,4 +116,36 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}
+def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
+ let summary = [{
+ Describes the mapping between work item (WI) and the 2D tensor specified by the tensor descriptor.
+ }];
+ let description = [{
+ To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
+ attribute at tensor descriptor creation time.
+ Within the `sg_map`, `wi_layout` specifies the layout of work items,
+ describing the mapping of work items to the tensor.
+ wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
+ `wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
+
+ E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+ In this example, the subgroup has 16 work items in wi_layout=[1, 16],
+ each accessing 1 element as specified by wi_data=[1, 1].
+
+ `wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
+ which is eventually lowered to "SIMT-flavor" vector, like SPIR-V vector or llvm vector.
+ The multiple elements indicated by `wi_data` can only come from one dimension and must be contiguous in memory along that dimension.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"uint32_t">:$wi_layout,
+ ArrayRefParameter<"uint32_t">:$wi_data);
+
+ let builders = [
+ AttrBuilder<(ins)>
+ ];
+
+ let hasCustomAssemblyFormat = 1;
+ let genVerifyDecl = 1;
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 9f101a71697b56..c0d7a08a6cb3a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
element-type ::= float-type | integer-type | index-type
dim-list := (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
- attr-list = (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
+ attr-list = (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
```
Examples:
@@ -77,12 +77,16 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+
+ // A TensorDesc with a sg_map
+ xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
```
}];
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
"mlir::Type": $elementType,
- OptionalParameter<"mlir::Attribute">: $encoding);
+ OptionalParameter<"mlir::Attribute">: $encoding,
+ OptionalParameter<"mlir::Attribute">: $sg_map);
let builders = [
TypeBuilderWithInferredContext<(ins
@@ -91,7 +95,9 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
CArg<"bool", "false">: $scattered,
CArg<"int", "1">: $array_length,
CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
- CArg<"bool", "true">: $boundary_check
+ CArg<"bool", "true">: $boundary_check,
+ CArg<"llvm::ArrayRef<uint32_t>", "{}">: $wi_layout,
+ CArg<"llvm::ArrayRef<uint32_t>", "{}">: $wi_data
)>
];
@@ -114,6 +120,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return llvm::dyn_cast_if_present<TensorDescAttr>(getEncoding());
}
+ SGMapAttr getSGMapAttr() const {
+ return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+ }
+
xegpu::MemoryScope getMemoryScope() const {
auto attr = getEncodingAsTensorDescAttr();
if (attr && attr.getMemoryScope())
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24719fe748fe4f..cbc140c465e1ec 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -44,6 +44,77 @@ TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_SGMapAttr
+//===----------------------------------------------------------------------===//
+namespace {
+template <typename T, unsigned N>
+LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
+ llvm::SmallVector<T, N> &result,
+ llvm::StringRef fieldName) {
+ if (failed(parser.parseKeyword(fieldName))) {
+ parser.emitError(parser.getCurrentLocation(),
+ "unexpected field name. Expected " + fieldName + ".");
+ return failure();
+ }
+
+ if (failed(parser.parseEqual())) {
+ parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
+ return failure();
+ }
+
+ auto elemParser = [&]() -> llvm::ParseResult {
+ uint32_t elem = 0;
+ auto res = parser.parseInteger(elem);
+ result.push_back(elem);
+ return res;
+ };
+
+ return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+ elemParser, fieldName);
+}
+} // namespace
+
+mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
+ ::mlir::Type attrType) {
+ if (failed(parser.parseLess()))
+ return {};
+
+ llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
+ if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
+ return {};
+
+ if (failed(parser.parseComma()))
+ return {};
+
+ if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
+ return {};
+
+ return SGMapAttr::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); },
+ parser.getContext(), wi_layout, wi_data);
+}
+
+void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
+ printer << "<";
+ printer.printKeywordOrString("wi_layout");
+ printer << " = [" << getWiLayout() << "], ";
+ printer.printKeywordOrString("wi_data");
+ printer << " = [" << getWiData() << "]";
+ printer << ">";
+}
+
+LogicalResult
+SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<uint32_t> wi_layout,
+ llvm::ArrayRef<uint32_t> wi_data) {
+ if (wi_layout.size() != 2)
+ return emitError() << "expected wi_layout of size 2";
+ if (wi_data.size() != 2)
+ return emitError() << "expected wi_data of size 2";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
@@ -51,6 +122,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
+ mlir::FailureOr<mlir::Attribute> sg_map;
// Parse literal '<'
if (parser.parseLess())
@@ -69,14 +141,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
}
// parse optional attributes
- if (mlir::succeeded(parser.parseOptionalComma())) {
- encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
- if (mlir::failed(encoding)) {
- parser.emitError(
- parser.getCurrentLocation(),
- "Failed to parse the attribute field for TensorDescType.\n");
- return {};
+ while (mlir::succeeded(parser.parseOptionalComma())) {
+ mlir::Attribute attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::succeeded(res)) {
+ if (mlir::isa<SGMapAttr>(attr)) {
+ sg_map = attr;
+ continue;
+ }
+ if (mlir::isa<TensorDescAttr>(attr)) {
+ encoding = attr;
+ continue;
+ }
}
+ parser.emitError(parser.getCurrentLocation(),
+ "Failed to parse the attribute.\n");
+ return {};
}
// Parse literal '>'
@@ -84,7 +164,8 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
return {};
return TensorDescType::get(parser.getContext(), shape, elementType,
- encoding.value_or(mlir::Attribute()));
+ encoding.value_or(mlir::Attribute()),
+ sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -104,17 +185,23 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
if (auto encoding = getEncoding())
printer << ", " << encoding;
+ if (auto sg_map = getSgMap())
+ printer << ", " << sg_map;
+
printer << ">";
}
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType, bool scattered,
int array_length, MemoryScope memory_scope,
- bool boundary_check) {
+ bool boundary_check,
+ llvm::ArrayRef<uint32_t> wi_layout,
+ llvm::ArrayRef<uint32_t> wi_data) {
auto context = elementType.getContext();
- auto attr = TensorDescAttr::get(context, memory_scope, array_length,
- boundary_check, scattered);
- return Base::get(context, shape, elementType, attr);
+ auto tensorDescAttr = TensorDescAttr::get(context, memory_scope, array_length,
+ boundary_check, scattered);
+ auto sgMapAttr = SGMapAttr::get(context, wi_layout, wi_data);
+ return Base::get(context, shape, elementType, tensorDescAttr, sgMapAttr);
}
} // namespace xegpu
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 35d44cf56a239b..eb242f8aa88e64 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -36,6 +44,13 @@ gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_vc_5(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_vc_5(%src: memref<24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
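For reference, the textual form handled by the custom assembly format in this diff can be approximated with a short Python sketch. This is an illustration, not the MLIR implementation; `parse_sg_map` and `print_sg_map` are hypothetical helpers that mirror the behavior of `SGMapAttr::parse` and `SGMapAttr::print`:

```python
import re

def parse_sg_map(text: str):
    """Parse '<wi_layout = [a, b], wi_data = [c, d]>' into two int lists,
    loosely mirroring SGMapAttr::parse from the diff."""
    m = re.fullmatch(
        r"<\s*wi_layout\s*=\s*\[([^\]]*)\]\s*,"
        r"\s*wi_data\s*=\s*\[([^\]]*)\]\s*>", text.strip())
    if m is None:
        raise ValueError("failed to parse sg_map")
    to_ints = lambda s: [int(v) for v in s.split(",")]
    wi_layout, wi_data = to_ints(m.group(1)), to_ints(m.group(2))
    # Same constraint as SGMapAttr::verify: both fields must be rank 2.
    if len(wi_layout) != 2 or len(wi_data) != 2:
        raise ValueError("expected wi_layout/wi_data of size 2")
    return wi_layout, wi_data

def print_sg_map(wi_layout, wi_data) -> str:
    # Mirrors SGMapAttr::print: '<wi_layout = [...], wi_data = [...]>'.
    fmt = lambda xs: ", ".join(str(x) for x in xs)
    return f"<wi_layout = [{fmt(wi_layout)}], wi_data = [{fmt(wi_data)}]>"
```

A round trip `parse_sg_map(print_sg_map(...))` returns the original lists, which is essentially what the MLIR test file above checks via the printer/parser pair.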
LGTM.
  let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
                        "mlir::Type": $elementType,
-                       OptionalParameter<"mlir::Attribute">: $encoding);
+                       OptionalParameter<"mlir::Attribute">: $encoding,
+                       OptionalParameter<"mlir::Attribute">: $sg_map);
Is it also available for scattered TensorDesc (created by create_tdesc), which has slightly different semantics than the blocked TensorDesc created by create_nd_tdesc? If so, are there any restrictions on it?
> Is it also available for scattered TensorDesc (created by create_tdesc)?

Yes, it produces the same type, so the map can (and should) be attached to the scattered variation. It works now, but I'll add a test to cover it.

> If so, are there any restrictions on it?

There are no additional restrictions on the layout; it only has to match the subgroup size. The wi_data field represents the minimum requirement for the distribution (e.g., to cover packing). We may want to bind it to the chunk size somehow, but I don't think that is necessary at the moment.
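The distribution discussed in this thread can be illustrated with a small sketch. This is assumption-laden illustration, not code from the patch: it assumes a subgroup size of 16 and a simple blocked distribution in which each work item receives `shape[i] // wi_layout[i]` elements per dimension, in units of at least `wi_data[i]`:

```python
def wi_fragment(tensor_shape, wi_layout, wi_data, subgroup_size=16):
    # Illustration only (assumed semantics): wi_layout tiles the 2D tensor
    # over the work items of a subgroup; each work item then owns
    # shape[i] // wi_layout[i] elements per dimension, which must be a
    # multiple of the per-distribution minimum wi_data[i].
    assert wi_layout[0] * wi_layout[1] == subgroup_size, \
        "wi_layout must cover the whole subgroup"
    assert all(s % l == 0 for s, l in zip(tensor_shape, wi_layout))
    frag = [s // l for s, l in zip(tensor_shape, wi_layout)]
    assert all(f % d == 0 for f, d in zip(frag, wi_data))
    return frag

# The 8x16 tile with wi_layout = [1, 16], wi_data = [1, 1] from the tests:
# each of the 16 work items ends up with an 8x1 fragment.
print(wi_fragment([8, 16], [1, 16], [1, 1]))  # -> [8, 1]
```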
LGTM
Force-pushed from 89e4cd7 to 4ebec7a.
-      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
I had to slightly modify the signature since the two new parameters conflict with the existing defaulted ones (a boundary_check argument gets implicitly converted to the memory space parameter). I think it's better anyway.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/6929

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/6021

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/4508

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/2432

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/7970

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/4519
… semantics (llvm#108864)" This reverts commit 3ca5d80.
…cs (llvm#108864) The PR adds an attribute (sg_map) describing the distribution of computation among work items for xegpu operations to be used in lowering passes. The map is attached to the tensor descriptor, so the constructor and the type are updated. Tests check the custom parser & printer. The attribute is optional now, so no other changes required. The complete description of the attribute can be found [here](https://github.com/intel/mlir-extensions/blob/main/docs/rfcs/XeGPU.md#xegpu-attributes-to-support-work-item-level-semantics).
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/10744
… semantics" (llvm#110871) Reverts llvm#108864 since it breaks compilation