@@ -55,77 +55,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
55
55
return Base::get (context, scopeAttr, chunkSizeAttr);
56
56
}
57
57
58
- // ===----------------------------------------------------------------------===//
59
- // XeGPU_SGMapAttr
60
- // ===----------------------------------------------------------------------===//
61
- namespace {
62
- template <typename T, unsigned N>
63
- LogicalResult parseIntArrayField (::mlir::AsmParser &parser,
64
- llvm::SmallVector<T, N> &result,
65
- llvm::StringRef fieldName) {
66
- if (failed (parser.parseKeyword (fieldName))) {
67
- parser.emitError (parser.getCurrentLocation (),
68
- " unexpected field name. Expected " + fieldName + " ." );
69
- return failure ();
70
- }
71
-
72
- if (failed (parser.parseEqual ())) {
73
- parser.emitError (parser.getCurrentLocation (), " expected '=' sign." );
74
- return failure ();
75
- }
76
-
77
- auto elemParser = [&]() -> llvm::ParseResult {
78
- uint32_t elem = 0 ;
79
- auto res = parser.parseInteger (elem);
80
- result.push_back (elem);
81
- return res;
82
- };
83
-
84
- return parser.parseCommaSeparatedList (AsmParser::Delimiter::Square,
85
- elemParser, fieldName);
86
- }
87
- } // namespace
88
-
89
- mlir::Attribute SGMapAttr::parse (::mlir::AsmParser &parser,
90
- ::mlir::Type attrType) {
91
- if (failed (parser.parseLess ()))
92
- return {};
93
-
94
- llvm::SmallVector<uint32_t , 2 > wi_layout, wi_data;
95
- if (failed (parseIntArrayField (parser, wi_layout, " wi_layout" )))
96
- return {};
97
-
98
- if (failed (parser.parseComma ()))
99
- return {};
100
-
101
- if (failed (parseIntArrayField (parser, wi_data, " wi_data" )))
102
- return {};
103
-
104
- return SGMapAttr::getChecked (
105
- [&]() { return parser.emitError (parser.getNameLoc ()); },
106
- parser.getContext (), wi_layout, wi_data);
107
- }
108
-
109
- void SGMapAttr::print (::mlir::AsmPrinter &printer) const {
110
- printer << " <" ;
111
- printer.printKeywordOrString (" wi_layout" );
112
- printer << " = [" << getWiLayout () << " ], " ;
113
- printer.printKeywordOrString (" wi_data" );
114
- printer << " = [" << getWiData () << " ]" ;
115
- printer << " >" ;
116
- }
117
-
118
- LogicalResult
119
- SGMapAttr::verify (llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120
- llvm::ArrayRef<uint32_t > wi_layout,
121
- llvm::ArrayRef<uint32_t > wi_data) {
122
- if (wi_layout.size () != 2 )
123
- return emitError () << " expected wi_layout of size 2" ;
124
- if (wi_data.size () != 2 )
125
- return emitError () << " expected wi_data of size 2" ;
126
- return success ();
127
- }
128
-
129
58
// ===----------------------------------------------------------------------===//
130
59
// XeGPU_TensorDescType
131
60
// ===----------------------------------------------------------------------===//
@@ -134,7 +63,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
134
63
llvm::SmallVector<int64_t > shape;
135
64
mlir::Type elementType;
136
65
mlir::FailureOr<mlir::Attribute> encoding;
137
- mlir::FailureOr<mlir::Attribute> sg_map;
138
66
139
67
// Parse literal '<'
140
68
if (parser.parseLess ())
@@ -153,31 +81,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
153
81
}
154
82
155
83
// parse optional attributes
156
- while (mlir::succeeded (parser.parseOptionalComma ())) {
157
- mlir::Attribute attr;
158
- ParseResult res = parser.parseAttribute (attr);
159
- if (mlir::succeeded (res)) {
160
- if (mlir::isa<SGMapAttr>(attr)) {
161
- sg_map = attr;
162
- continue ;
163
- }
164
- if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165
- encoding = attr;
166
- continue ;
167
- }
84
+ if (mlir::succeeded (parser.parseOptionalComma ())) {
85
+ encoding = mlir::FieldParser<mlir::Attribute>::parse (parser);
86
+ if (mlir::failed (encoding)) {
87
+ parser.emitError (
88
+ parser.getCurrentLocation (),
89
+ " Failed to parse the attribute field for TensorDescType.\n " );
90
+ return {};
168
91
}
169
- parser.emitError (parser.getCurrentLocation (),
170
- " Failed to parse the attribute.\n " );
171
- return {};
172
92
}
173
93
174
94
// Parse literal '>'
175
95
if (parser.parseGreater ())
176
96
return {};
177
97
178
98
return TensorDescType::get (parser.getContext (), shape, elementType,
179
- encoding.value_or (mlir::Attribute ()),
180
- sg_map.value_or (mlir::Attribute ()));
99
+ encoding.value_or (mlir::Attribute ()));
181
100
}
182
101
183
102
void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -197,30 +116,25 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
197
116
if (auto encoding = getEncoding ())
198
117
printer << " , " << encoding;
199
118
200
- if (auto sg_map = getSgMap ())
201
- printer << " , " << sg_map;
202
-
203
119
printer << " >" ;
204
120
}
205
121
206
122
TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
207
123
mlir::Type elementType, int array_length,
208
124
bool boundary_check,
209
- MemorySpace memory_space,
210
- mlir::Attribute sg_map) {
125
+ MemorySpace memory_space) {
211
126
auto context = elementType.getContext ();
212
127
auto attr = BlockTensorDescAttr::get (context, memory_space, array_length,
213
128
boundary_check);
214
- return Base::get (context, shape, elementType, attr, sg_map );
129
+ return Base::get (context, shape, elementType, attr);
215
130
}
216
131
217
132
TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
218
133
mlir::Type elementType, int chunk_size,
219
- MemorySpace memory_space,
220
- mlir::Attribute sg_map) {
134
+ MemorySpace memory_space) {
221
135
auto context = elementType.getContext ();
222
136
auto attr = ScatterTensorDescAttr::get (context, memory_space, chunk_size);
223
- return Base::get (context, shape, elementType, attr, sg_map );
137
+ return Base::get (context, shape, elementType, attr);
224
138
}
225
139
226
140
} // namespace xegpu
0 commit comments