Skip to content

Commit 9fa55ec

Browse files
[MLIR][XeGPU] Add sg_map attribute to support Work Item level semanti… (#110876)
Bring back #108864
1 parent 9016f27 commit 9fa55ec

File tree

4 files changed

+169
-17
lines changed

4 files changed

+169
-17
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,36 @@ def XeGPU_FenceScopeAttr:
142142
let assemblyFormat = "$value";
143143
}
144144

145+
def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
146+
let summary = [{
147+
Describes the mapping between work item (WI) and the 2D tensor specified by the tensor descriptor.
148+
}];
149+
let description = [{
150+
To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
151+
attribute at the tensor description creation time.
152+
Within the `sg_map`, `wi_layout` specifies the layout of work items,
153+
describing the mapping of work items to the tensor.
154+
wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
155+
`wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
156+
157+
E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
158+
In this example, the subgroup has 16 work items in wi_layout=[1, 16],
159+
each accessing 1 element as specified by wi_data=[1, 1].
160+
161+
`wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
162+
which is eventually lowered to "SIMT-flavor" vector, like SPIR-V vector or llvm vector, or packed to a storage data type.
163+
The multiple elements indicated by `wi_data` can only be from one dimension and must be contiguous in the memory along either dimension.
164+
}];
165+
let parameters = (ins
166+
ArrayRefParameter<"uint32_t">:$wi_layout,
167+
ArrayRefParameter<"uint32_t">:$wi_data);
168+
169+
let builders = [
170+
AttrBuilder<(ins)>
171+
];
172+
173+
let hasCustomAssemblyFormat = 1;
174+
let genVerifyDecl = 1;
175+
}
176+
145177
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
6363
element-type ::= float-type | integer-type | index-type
6464
dim-list := (static-dim-list `x`)?
6565
static-dim-list ::= decimal-literal `x` decimal-literal
66-
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
66+
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
6767
```
6868

6969
Examples:
@@ -77,27 +77,33 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
7777

7878
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
7979
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
80+
81+
// A TensorDesc with a sg_map
82+
xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
8083
```
8184
}];
8285

8386
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
8487
"mlir::Type": $elementType,
85-
OptionalParameter<"mlir::Attribute">: $encoding);
88+
OptionalParameter<"mlir::Attribute">: $encoding,
89+
OptionalParameter<"mlir::Attribute">: $sg_map);
8690

8791
let builders = [
8892
TypeBuilderWithInferredContext<(ins
8993
"llvm::ArrayRef<int64_t>": $shape,
9094
"mlir::Type": $elementType,
9195
CArg<"int", "1">: $array_length,
9296
CArg<"bool", "true">: $boundary_check,
93-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
97+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
98+
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
9499
TypeBuilderWithInferredContext<(ins
95100
"llvm::ArrayRef<int64_t>": $shape,
96101
"mlir::Type": $elementType,
97102
CArg<"int", "1">: $chunk_size,
98-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
103+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
104+
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
99105
];
100-
106+
101107
let extraClassDeclaration = [{
102108
using TensorType::clone;
103109
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -121,6 +127,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
121127
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
122128
}
123129

130+
SGMapAttr getSGMapAttr() const {
131+
return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
132+
}
133+
124134
xegpu::MemorySpace getMemorySpace() const {
125135
auto block_attr = getEncodingAsBlockTensorDescAttr();
126136
if (block_attr && block_attr.getMemorySpace())

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555
return Base::get(context, scopeAttr, chunkSizeAttr);
5656
}
5757

58+
//===----------------------------------------------------------------------===//
59+
// XeGPU_SGMapAttr
60+
//===----------------------------------------------------------------------===//
61+
namespace {
62+
template <typename T, unsigned N>
63+
LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
64+
llvm::SmallVector<T, N> &result,
65+
llvm::StringRef fieldName) {
66+
if (failed(parser.parseKeyword(fieldName))) {
67+
parser.emitError(parser.getCurrentLocation(),
68+
"unexpected field name. Expected " + fieldName + ".");
69+
return failure();
70+
}
71+
72+
if (failed(parser.parseEqual())) {
73+
parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74+
return failure();
75+
}
76+
77+
auto elemParser = [&]() -> llvm::ParseResult {
78+
uint32_t elem = 0;
79+
auto res = parser.parseInteger(elem);
80+
result.push_back(elem);
81+
return res;
82+
};
83+
84+
return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
85+
elemParser, fieldName);
86+
}
87+
} // namespace
88+
89+
mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
90+
::mlir::Type attrType) {
91+
if (failed(parser.parseLess()))
92+
return {};
93+
94+
llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95+
if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96+
return {};
97+
98+
if (failed(parser.parseComma()))
99+
return {};
100+
101+
if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102+
return {};
103+
104+
return SGMapAttr::getChecked(
105+
[&]() { return parser.emitError(parser.getNameLoc()); },
106+
parser.getContext(), wi_layout, wi_data);
107+
}
108+
109+
void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110+
printer << "<";
111+
printer.printKeywordOrString("wi_layout");
112+
printer << " = [" << getWiLayout() << "], ";
113+
printer.printKeywordOrString("wi_data");
114+
printer << " = [" << getWiData() << "]";
115+
printer << ">";
116+
}
117+
118+
LogicalResult
119+
SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120+
llvm::ArrayRef<uint32_t> wi_layout,
121+
llvm::ArrayRef<uint32_t> wi_data) {
122+
if (wi_layout.size() != 2)
123+
return emitError() << "expected wi_layout of size 2";
124+
if (wi_data.size() != 2)
125+
return emitError() << "expected wi_data of size 2";
126+
return success();
127+
}
128+
58129
//===----------------------------------------------------------------------===//
59130
// XeGPU_TensorDescType
60131
//===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
63134
llvm::SmallVector<int64_t> shape;
64135
mlir::Type elementType;
65136
mlir::FailureOr<mlir::Attribute> encoding;
137+
mlir::FailureOr<mlir::Attribute> sg_map;
66138

67139
// Parse literal '<'
68140
if (parser.parseLess())
@@ -81,22 +153,31 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
81153
}
82154

83155
// parse optional attributes
84-
if (mlir::succeeded(parser.parseOptionalComma())) {
85-
encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
86-
if (mlir::failed(encoding)) {
87-
parser.emitError(
88-
parser.getCurrentLocation(),
89-
"Failed to parse the attribute field for TensorDescType.\n");
90-
return {};
156+
while (mlir::succeeded(parser.parseOptionalComma())) {
157+
mlir::Attribute attr;
158+
ParseResult res = parser.parseAttribute(attr);
159+
if (mlir::succeeded(res)) {
160+
if (mlir::isa<SGMapAttr>(attr)) {
161+
sg_map = attr;
162+
continue;
163+
}
164+
if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165+
encoding = attr;
166+
continue;
167+
}
91168
}
169+
parser.emitError(parser.getCurrentLocation(),
170+
"Failed to parse the attribute.\n");
171+
return {};
92172
}
93173

94174
// Parse literal '>'
95175
if (parser.parseGreater())
96176
return {};
97177

98178
return TensorDescType::get(parser.getContext(), shape, elementType,
99-
encoding.value_or(mlir::Attribute()));
179+
encoding.value_or(mlir::Attribute()),
180+
sg_map.value_or(mlir::Attribute()));
100181
}
101182

102183
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
116197
if (auto encoding = getEncoding())
117198
printer << ", " << encoding;
118199

200+
if (auto sg_map = getSgMap())
201+
printer << ", " << sg_map;
202+
119203
printer << ">";
120204
}
121205

122206
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
123207
mlir::Type elementType, int array_length,
124208
bool boundary_check,
125-
MemorySpace memory_space) {
209+
MemorySpace memory_space,
210+
mlir::Attribute sg_map) {
126211
auto context = elementType.getContext();
127212
auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
128213
boundary_check);
129-
return Base::get(context, shape, elementType, attr);
214+
return Base::get(context, shape, elementType, attr, sg_map);
130215
}
131216

132217
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
133218
mlir::Type elementType, int chunk_size,
134-
MemorySpace memory_space) {
219+
MemorySpace memory_space,
220+
mlir::Attribute sg_map) {
135221
auto context = elementType.getContext();
136222
auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
137-
return Base::get(context, shape, elementType, attr);
223+
return Base::get(context, shape, elementType, attr, sg_map);
138224
}
139225

140226
} // namespace xegpu

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
1313
gpu.return
1414
}
1515

16+
// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
17+
gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
18+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
19+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
20+
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
21+
gpu.return
22+
}
23+
1624
// CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
1725
gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
1826
//CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -43,6 +51,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
4351
gpu.return
4452
}
4553

54+
// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
55+
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
56+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
57+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
58+
gpu.return
59+
}
60+
4661
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
4762
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
4863
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -120,6 +135,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
120135
gpu.return
121136
}
122137

138+
// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
139+
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
140+
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
141+
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
142+
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
143+
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
144+
gpu.return
145+
}
146+
123147
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
124148
gpu.func @test_prefetch_vc(%src: ui64) {
125149
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

0 commit comments

Comments
 (0)