Skip to content

Commit 3268b52

Browse files
committed
Revert "[MLIR][XeGPU] Add sg_map attribute to support Work Item level semantics (llvm#108864)"
This reverts commit 3ca5d80.
1 parent 1c01bcb commit 3268b52

File tree

4 files changed

+16
-166
lines changed

4 files changed

+16
-166
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -142,36 +142,4 @@ def XeGPU_FenceScopeAttr:
142142
let assemblyFormat = "$value";
143143
}
144144

145-
def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
146-
let summary = [{
147-
Describes the mapping between work item (WI) and the 2D tensor specified by the tensor descriptor.
148-
}];
149-
let description = [{
150-
To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
151-
attribute at the tensor description creation time.
152-
Within the `sg_map`, `wi_layout` specifies the layout of work items,
153-
describing the mapping of work items to the tensor.
154-
wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
155-
`wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
156-
157-
E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
158-
In this example, the subgroup has 16 work items in wi_layout=[1, 16],
159-
each accessing 1 element as specified by wi_data=[1, 1].
160-
161-
`wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
162-
which is eventually lowered to "SIMT-flavor" vector, like SPIR-V vector or llvm vector, or packed to a storage data type.
163-
The multiple elements indicated by `wi_data` can only be from one dimension and must be contiguous in the memory along either dimension.
164-
}];
165-
let parameters = (ins
166-
ArrayRefParameter<"uint32_t">:$wi_layout,
167-
ArrayRefParameter<"uint32_t">:$wi_data);
168-
169-
let builders = [
170-
AttrBuilder<(ins)>
171-
];
172-
173-
let hasCustomAssemblyFormat = 1;
174-
let genVerifyDecl = 1;
175-
}
176-
177145
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
6363
element-type ::= float-type | integer-type | index-type
6464
dim-list := (static-dim-list `x`)?
6565
static-dim-list ::= decimal-literal `x` decimal-literal
66-
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
66+
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
6767
```
6868

6969
Examples:
@@ -77,31 +77,25 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
7777

7878
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
7979
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
80-
81-
// A TensorDesc with a sg_map
82-
xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
8380
```
8481
}];
8582

8683
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
8784
"mlir::Type": $elementType,
88-
OptionalParameter<"mlir::Attribute">: $encoding,
89-
OptionalParameter<"mlir::Attribute">: $sg_map);
85+
OptionalParameter<"mlir::Attribute">: $encoding);
9086

9187
let builders = [
9288
TypeBuilderWithInferredContext<(ins
9389
"llvm::ArrayRef<int64_t>": $shape,
9490
"mlir::Type": $elementType,
9591
CArg<"int", "1">: $array_length,
9692
CArg<"bool", "true">: $boundary_check,
97-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
98-
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
93+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
9994
TypeBuilderWithInferredContext<(ins
10095
"llvm::ArrayRef<int64_t>": $shape,
10196
"mlir::Type": $elementType,
10297
CArg<"int", "1">: $chunk_size,
103-
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
104-
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
98+
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
10599
];
106100

107101
let extraClassDeclaration = [{
@@ -127,10 +121,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
127121
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
128122
}
129123

130-
SGMapAttr getSGMapAttr() const {
131-
return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
132-
}
133-
134124
xegpu::MemorySpace getMemorySpace() const {
135125
auto block_attr = getEncodingAsBlockTensorDescAttr();
136126
if (block_attr && block_attr.getMemorySpace())

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 12 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -55,77 +55,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555
return Base::get(context, scopeAttr, chunkSizeAttr);
5656
}
5757

58-
//===----------------------------------------------------------------------===//
59-
// XeGPU_SGMapAttr
60-
//===----------------------------------------------------------------------===//
61-
namespace {
62-
template <typename T, unsigned N>
63-
LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
64-
llvm::SmallVector<T, N> &result,
65-
llvm::StringRef fieldName) {
66-
if (failed(parser.parseKeyword(fieldName))) {
67-
parser.emitError(parser.getCurrentLocation(),
68-
"unexpected field name. Expected " + fieldName + ".");
69-
return failure();
70-
}
71-
72-
if (failed(parser.parseEqual())) {
73-
parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74-
return failure();
75-
}
76-
77-
auto elemParser = [&]() -> llvm::ParseResult {
78-
uint32_t elem = 0;
79-
auto res = parser.parseInteger(elem);
80-
result.push_back(elem);
81-
return res;
82-
};
83-
84-
return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
85-
elemParser, fieldName);
86-
}
87-
} // namespace
88-
89-
mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
90-
::mlir::Type attrType) {
91-
if (failed(parser.parseLess()))
92-
return {};
93-
94-
llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95-
if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96-
return {};
97-
98-
if (failed(parser.parseComma()))
99-
return {};
100-
101-
if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102-
return {};
103-
104-
return SGMapAttr::getChecked(
105-
[&]() { return parser.emitError(parser.getNameLoc()); },
106-
parser.getContext(), wi_layout, wi_data);
107-
}
108-
109-
void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110-
printer << "<";
111-
printer.printKeywordOrString("wi_layout");
112-
printer << " = [" << getWiLayout() << "], ";
113-
printer.printKeywordOrString("wi_data");
114-
printer << " = [" << getWiData() << "]";
115-
printer << ">";
116-
}
117-
118-
LogicalResult
119-
SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120-
llvm::ArrayRef<uint32_t> wi_layout,
121-
llvm::ArrayRef<uint32_t> wi_data) {
122-
if (wi_layout.size() != 2)
123-
return emitError() << "expected wi_layout of size 2";
124-
if (wi_data.size() != 2)
125-
return emitError() << "expected wi_data of size 2";
126-
return success();
127-
}
128-
12958
//===----------------------------------------------------------------------===//
13059
// XeGPU_TensorDescType
13160
//===----------------------------------------------------------------------===//
@@ -134,7 +63,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
13463
llvm::SmallVector<int64_t> shape;
13564
mlir::Type elementType;
13665
mlir::FailureOr<mlir::Attribute> encoding;
137-
mlir::FailureOr<mlir::Attribute> sg_map;
13866

13967
// Parse literal '<'
14068
if (parser.parseLess())
@@ -153,31 +81,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
15381
}
15482

15583
// parse optional attributes
156-
while (mlir::succeeded(parser.parseOptionalComma())) {
157-
mlir::Attribute attr;
158-
ParseResult res = parser.parseAttribute(attr);
159-
if (mlir::succeeded(res)) {
160-
if (mlir::isa<SGMapAttr>(attr)) {
161-
sg_map = attr;
162-
continue;
163-
}
164-
if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165-
encoding = attr;
166-
continue;
167-
}
84+
if (mlir::succeeded(parser.parseOptionalComma())) {
85+
encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
86+
if (mlir::failed(encoding)) {
87+
parser.emitError(
88+
parser.getCurrentLocation(),
89+
"Failed to parse the attribute field for TensorDescType.\n");
90+
return {};
16891
}
169-
parser.emitError(parser.getCurrentLocation(),
170-
"Failed to parse the attribute.\n");
171-
return {};
17292
}
17393

17494
// Parse literal '>'
17595
if (parser.parseGreater())
17696
return {};
17797

17898
return TensorDescType::get(parser.getContext(), shape, elementType,
179-
encoding.value_or(mlir::Attribute()),
180-
sg_map.value_or(mlir::Attribute()));
99+
encoding.value_or(mlir::Attribute()));
181100
}
182101

183102
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -197,30 +116,25 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
197116
if (auto encoding = getEncoding())
198117
printer << ", " << encoding;
199118

200-
if (auto sg_map = getSgMap())
201-
printer << ", " << sg_map;
202-
203119
printer << ">";
204120
}
205121

206122
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
207123
mlir::Type elementType, int array_length,
208124
bool boundary_check,
209-
MemorySpace memory_space,
210-
mlir::Attribute sg_map) {
125+
MemorySpace memory_space) {
211126
auto context = elementType.getContext();
212127
auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
213128
boundary_check);
214-
return Base::get(context, shape, elementType, attr, sg_map);
129+
return Base::get(context, shape, elementType, attr);
215130
}
216131

217132
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
218133
mlir::Type elementType, int chunk_size,
219-
MemorySpace memory_space,
220-
mlir::Attribute sg_map) {
134+
MemorySpace memory_space) {
221135
auto context = elementType.getContext();
222136
auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
223-
return Base::get(context, shape, elementType, attr, sg_map);
137+
return Base::get(context, shape, elementType, attr);
224138
}
225139

226140
} // namespace xegpu

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
1313
gpu.return
1414
}
1515

16-
// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
17-
gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
18-
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
19-
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
20-
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
21-
gpu.return
22-
}
23-
2416
// CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
2517
gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
2618
//CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -51,13 +43,6 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
5143
gpu.return
5244
}
5345

54-
// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
55-
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
56-
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
57-
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
58-
gpu.return
59-
}
60-
6146
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
6247
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
6348
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -135,13 +120,6 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
135120
gpu.return
136121
}
137122

138-
// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
139-
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
140-
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
141-
%1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
142-
gpu.return
143-
}
144-
145123
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
146124
gpu.func @test_prefetch_vc(%src: ui64) {
147125
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

0 commit comments

Comments
 (0)