Skip to content

Commit 3ca5d80

Browse files
[MLIR][XeGPU] Add sg_map attribute to support Work Item level semantics (llvm#108864)
The PR adds an attribute (sg_map) describing the distribution of computation among work items for xegpu operations to be used in lowering passes. The map is attached to the tensor descriptor, so the constructor and the type are updated. Tests check the custom parser & printer. The attribute is optional now, so no other changes required. The complete description of the attribute can be found [here](https://github.com/intel/mlir-extensions/blob/main/docs/rfcs/XeGPU.md#xegpu-attributes-to-support-work-item-level-semantics).
1 parent 9b53a6e commit 3ca5d80

File tree

4 files changed

+166
-16
lines changed

4 files changed

+166
-16
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,36 @@ def XeGPU_FenceScopeAttr:
142142
let assemblyFormat = "$value";
143143
}
144144

145+
def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
146+
let summary = [{
147+
Describes the mapping between work item (WI) and the 2D tensor specified by the tensor descriptor.
148+
}];
149+
let description = [{
150+
To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
151+
attribute at the tensor descriptor creation time.
152+
Within the `sg_map`, `wi_layout` specifies the layout of work items,
153+
describing the mapping of work items to the tensor.
154+
wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
155+
`wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
156+
157+
E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
158+
In this example, the subgroup has 16 work items in wi_layout=[1, 16],
159+
each accessing 1 element as specified by wi_data=[1, 1].
160+
161+
`wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
162+
which is eventually lowered to "SIMT-flavor" vector, like SPIR-V vector or llvm vector, or packed to a storage data type.
163+
The multiple elements indicated by `wi_data` can only be from one dimension and must be contiguous in the memory along either dimension.
164+
}];
165+
let parameters = (ins
166+
ArrayRefParameter<"uint32_t">:$wi_layout,
167+
ArrayRefParameter<"uint32_t">:$wi_data);
168+
169+
let builders = [
170+
AttrBuilder<(ins)>
171+
];
172+
173+
let hasCustomAssemblyFormat = 1;
174+
let genVerifyDecl = 1;
175+
}
176+
145177
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
6363
element-type ::= float-type | integer-type | index-type
6464
dim-list := (static-dim-list `x`)?
6565
static-dim-list ::= decimal-literal `x` decimal-literal
66-
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
66+
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
6767
```
6868

6969
Examples:
@@ -77,25 +77,31 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
7777

7878
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
7979
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
80+
81+
// A TensorDesc with a sg_map
82+
xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
8083
```
8184
}];
8285

8386
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
8487
"mlir::Type": $elementType,
85-
OptionalParameter<"mlir::Attribute">: $encoding);
88+
OptionalParameter<"mlir::Attribute">: $encoding,
89+
OptionalParameter<"mlir::Attribute">: $sg_map);
8690

8791
let builders = [
8892
TypeBuilderWithInferredContext<(ins
8993
"llvm::ArrayRef<int64_t>": $shape,
9094
"mlir::Type": $elementType,
9195
CArg<"int", "1">: $array_length,
9296
CArg<"bool", "true">: $boundary_check,
93-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
97+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
98+
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
9499
TypeBuilderWithInferredContext<(ins
95100
"llvm::ArrayRef<int64_t>": $shape,
96101
"mlir::Type": $elementType,
97102
CArg<"int", "1">: $chunk_size,
98-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
103+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
104+
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
99105
];
100106

101107
let extraClassDeclaration = [{
@@ -121,6 +127,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
121127
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
122128
}
123129

130+
SGMapAttr getSGMapAttr() const {
131+
return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
132+
}
133+
124134
xegpu::MemorySpace getMemorySpace() const {
125135
auto block_attr = getEncodingAsBlockTensorDescAttr();
126136
if (block_attr && block_attr.getMemorySpace())

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555
return Base::get(context, scopeAttr, chunkSizeAttr);
5656
}
5757

58+
//===----------------------------------------------------------------------===//
59+
// XeGPU_SGMapAttr
60+
//===----------------------------------------------------------------------===//
61+
namespace {
62+
template <typename T, unsigned N>
63+
LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
64+
llvm::SmallVector<T, N> &result,
65+
llvm::StringRef fieldName) {
66+
if (failed(parser.parseKeyword(fieldName))) {
67+
parser.emitError(parser.getCurrentLocation(),
68+
"unexpected field name. Expected " + fieldName + ".");
69+
return failure();
70+
}
71+
72+
if (failed(parser.parseEqual())) {
73+
parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74+
return failure();
75+
}
76+
77+
auto elemParser = [&]() -> llvm::ParseResult {
78+
uint32_t elem = 0;
79+
auto res = parser.parseInteger(elem);
80+
result.push_back(elem);
81+
return res;
82+
};
83+
84+
return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
85+
elemParser, fieldName);
86+
}
87+
} // namespace
88+
89+
mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
90+
::mlir::Type attrType) {
91+
if (failed(parser.parseLess()))
92+
return {};
93+
94+
llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95+
if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96+
return {};
97+
98+
if (failed(parser.parseComma()))
99+
return {};
100+
101+
if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102+
return {};
103+
104+
return SGMapAttr::getChecked(
105+
[&]() { return parser.emitError(parser.getNameLoc()); },
106+
parser.getContext(), wi_layout, wi_data);
107+
}
108+
109+
void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110+
printer << "<";
111+
printer.printKeywordOrString("wi_layout");
112+
printer << " = [" << getWiLayout() << "], ";
113+
printer.printKeywordOrString("wi_data");
114+
printer << " = [" << getWiData() << "]";
115+
printer << ">";
116+
}
117+
118+
LogicalResult
119+
SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120+
llvm::ArrayRef<uint32_t> wi_layout,
121+
llvm::ArrayRef<uint32_t> wi_data) {
122+
if (wi_layout.size() != 2)
123+
return emitError() << "expected wi_layout of size 2";
124+
if (wi_data.size() != 2)
125+
return emitError() << "expected wi_data of size 2";
126+
return success();
127+
}
128+
58129
//===----------------------------------------------------------------------===//
59130
// XeGPU_TensorDescType
60131
//===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
63134
llvm::SmallVector<int64_t> shape;
64135
mlir::Type elementType;
65136
mlir::FailureOr<mlir::Attribute> encoding;
137+
mlir::FailureOr<mlir::Attribute> sg_map;
66138

67139
// Parse literal '<'
68140
if (parser.parseLess())
@@ -81,22 +153,31 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
81153
}
82154

83155
// parse optional attributes
84-
if (mlir::succeeded(parser.parseOptionalComma())) {
85-
encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
86-
if (mlir::failed(encoding)) {
87-
parser.emitError(
88-
parser.getCurrentLocation(),
89-
"Failed to parse the attribute field for TensorDescType.\n");
90-
return {};
156+
while (mlir::succeeded(parser.parseOptionalComma())) {
157+
mlir::Attribute attr;
158+
ParseResult res = parser.parseAttribute(attr);
159+
if (mlir::succeeded(res)) {
160+
if (mlir::isa<SGMapAttr>(attr)) {
161+
sg_map = attr;
162+
continue;
163+
}
164+
if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165+
encoding = attr;
166+
continue;
167+
}
91168
}
169+
parser.emitError(parser.getCurrentLocation(),
170+
"Failed to parse the attribute.\n");
171+
return {};
92172
}
93173

94174
// Parse literal '>'
95175
if (parser.parseGreater())
96176
return {};
97177

98178
return TensorDescType::get(parser.getContext(), shape, elementType,
99-
encoding.value_or(mlir::Attribute()));
179+
encoding.value_or(mlir::Attribute()),
180+
sg_map.value_or(mlir::Attribute()));
100181
}
101182

102183
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
116197
if (auto encoding = getEncoding())
117198
printer << ", " << encoding;
118199

200+
if (auto sg_map = getSgMap())
201+
printer << ", " << sg_map;
202+
119203
printer << ">";
120204
}
121205

122206
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
123207
mlir::Type elementType, int array_length,
124208
bool boundary_check,
125-
MemorySpace memory_space) {
209+
MemorySpace memory_space,
210+
mlir::Attribute sg_map) {
126211
auto context = elementType.getContext();
127212
auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
128213
boundary_check);
129-
return Base::get(context, shape, elementType, attr);
214+
return Base::get(context, shape, elementType, attr, sg_map);
130215
}
131216

132217
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
133218
mlir::Type elementType, int chunk_size,
134-
MemorySpace memory_space) {
219+
MemorySpace memory_space,
220+
mlir::Attribute sg_map) {
135221
auto context = elementType.getContext();
136222
auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
137-
return Base::get(context, shape, elementType, attr);
223+
return Base::get(context, shape, elementType, attr, sg_map);
138224
}
139225

140226
} // namespace xegpu

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
1313
gpu.return
1414
}
1515

16+
// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
17+
gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
18+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
19+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
20+
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
21+
gpu.return
22+
}
23+
1624
// CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
1725
gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
1826
//CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -43,6 +51,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
4351
gpu.return
4452
}
4553

54+
// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
55+
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
56+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
57+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
58+
gpu.return
59+
}
60+
4661
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
4762
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
4863
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -120,6 +135,13 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
120135
gpu.return
121136
}
122137

138+
// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
139+
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
140+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
141+
%1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
142+
gpu.return
143+
}
144+
123145
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
124146
gpu.func @test_prefetch_vc(%src: ui64) {
125147
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

0 commit comments

Comments
 (0)