@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
55
55
return Base::get (context, scopeAttr, chunkSizeAttr);
56
56
}
57
57
58
+ // ===----------------------------------------------------------------------===//
59
+ // XeGPU_SGMapAttr
60
+ // ===----------------------------------------------------------------------===//
61
+ namespace {
62
+ template <typename T, unsigned N>
63
+ LogicalResult parseIntArrayField (::mlir::AsmParser &parser,
64
+ llvm::SmallVector<T, N> &result,
65
+ llvm::StringRef fieldName) {
66
+ if (failed (parser.parseKeyword (fieldName))) {
67
+ parser.emitError (parser.getCurrentLocation (),
68
+ " unexpected field name. Expected " + fieldName + " ." );
69
+ return failure ();
70
+ }
71
+
72
+ if (failed (parser.parseEqual ())) {
73
+ parser.emitError (parser.getCurrentLocation (), " expected '=' sign." );
74
+ return failure ();
75
+ }
76
+
77
+ auto elemParser = [&]() -> llvm::ParseResult {
78
+ uint32_t elem = 0 ;
79
+ auto res = parser.parseInteger (elem);
80
+ result.push_back (elem);
81
+ return res;
82
+ };
83
+
84
+ return parser.parseCommaSeparatedList (AsmParser::Delimiter::Square,
85
+ elemParser, fieldName);
86
+ }
87
+ } // namespace
88
+
89
+ mlir::Attribute SGMapAttr::parse (::mlir::AsmParser &parser,
90
+ ::mlir::Type attrType) {
91
+ if (failed (parser.parseLess ()))
92
+ return {};
93
+
94
+ llvm::SmallVector<uint32_t , 2 > wi_layout, wi_data;
95
+ if (failed (parseIntArrayField (parser, wi_layout, " wi_layout" )))
96
+ return {};
97
+
98
+ if (failed (parser.parseComma ()))
99
+ return {};
100
+
101
+ if (failed (parseIntArrayField (parser, wi_data, " wi_data" )))
102
+ return {};
103
+
104
+ return SGMapAttr::getChecked (
105
+ [&]() { return parser.emitError (parser.getNameLoc ()); },
106
+ parser.getContext (), wi_layout, wi_data);
107
+ }
108
+
109
+ void SGMapAttr::print (::mlir::AsmPrinter &printer) const {
110
+ printer << " <" ;
111
+ printer.printKeywordOrString (" wi_layout" );
112
+ printer << " = [" << getWiLayout () << " ], " ;
113
+ printer.printKeywordOrString (" wi_data" );
114
+ printer << " = [" << getWiData () << " ]" ;
115
+ printer << " >" ;
116
+ }
117
+
118
+ LogicalResult
119
+ SGMapAttr::verify (llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120
+ llvm::ArrayRef<uint32_t > wi_layout,
121
+ llvm::ArrayRef<uint32_t > wi_data) {
122
+ if (wi_layout.size () != 2 )
123
+ return emitError () << " expected wi_layout of size 2" ;
124
+ if (wi_data.size () != 2 )
125
+ return emitError () << " expected wi_data of size 2" ;
126
+ return success ();
127
+ }
128
+
58
129
// ===----------------------------------------------------------------------===//
59
130
// XeGPU_TensorDescType
60
131
// ===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
63
134
llvm::SmallVector<int64_t > shape;
64
135
mlir::Type elementType;
65
136
mlir::FailureOr<mlir::Attribute> encoding;
137
+ mlir::FailureOr<mlir::Attribute> sg_map;
66
138
67
139
// Parse literal '<'
68
140
if (parser.parseLess ())
@@ -81,22 +153,31 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
81
153
}
82
154
83
155
// parse optional attributes
84
- if (mlir::succeeded (parser.parseOptionalComma ())) {
85
- encoding = mlir::FieldParser<mlir::Attribute>::parse (parser);
86
- if (mlir::failed (encoding)) {
87
- parser.emitError (
88
- parser.getCurrentLocation (),
89
- " Failed to parse the attribute field for TensorDescType.\n " );
90
- return {};
156
+ while (mlir::succeeded (parser.parseOptionalComma ())) {
157
+ mlir::Attribute attr;
158
+ ParseResult res = parser.parseAttribute (attr);
159
+ if (mlir::succeeded (res)) {
160
+ if (mlir::isa<SGMapAttr>(attr)) {
161
+ sg_map = attr;
162
+ continue ;
163
+ }
164
+ if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165
+ encoding = attr;
166
+ continue ;
167
+ }
91
168
}
169
+ parser.emitError (parser.getCurrentLocation (),
170
+ " Failed to parse the attribute.\n " );
171
+ return {};
92
172
}
93
173
94
174
// Parse literal '>'
95
175
if (parser.parseGreater ())
96
176
return {};
97
177
98
178
return TensorDescType::get (parser.getContext (), shape, elementType,
99
- encoding.value_or (mlir::Attribute ()));
179
+ encoding.value_or (mlir::Attribute ()),
180
+ sg_map.value_or (mlir::Attribute ()));
100
181
}
101
182
102
183
void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
116
197
if (auto encoding = getEncoding ())
117
198
printer << " , " << encoding;
118
199
200
+ if (auto sg_map = getSgMap ())
201
+ printer << " , " << sg_map;
202
+
119
203
printer << " >" ;
120
204
}
121
205
122
206
TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
123
207
mlir::Type elementType, int array_length,
124
208
bool boundary_check,
125
- MemorySpace memory_space) {
209
+ MemorySpace memory_space,
210
+ mlir::Attribute sg_map) {
126
211
auto context = elementType.getContext ();
127
212
auto attr = BlockTensorDescAttr::get (context, memory_space, array_length,
128
213
boundary_check);
129
- return Base::get (context, shape, elementType, attr);
214
+ return Base::get (context, shape, elementType, attr, sg_map );
130
215
}
131
216
132
217
TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
133
218
mlir::Type elementType, int chunk_size,
134
- MemorySpace memory_space) {
219
+ MemorySpace memory_space,
220
+ mlir::Attribute sg_map) {
135
221
auto context = elementType.getContext ();
136
222
auto attr = ScatterTensorDescAttr::get (context, memory_space, chunk_size);
137
- return Base::get (context, shape, elementType, attr);
223
+ return Base::get (context, shape, elementType, attr, sg_map );
138
224
}
139
225
140
226
} // namespace xegpu
0 commit comments