@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
175
175
if (parser.parseGreater ())
176
176
return {};
177
177
178
- return TensorDescType::get (parser.getContext (), shape, elementType,
179
- encoding.value_or (mlir::Attribute ()),
180
- sg_map.value_or (mlir::Attribute ()));
178
+ return TensorDescType::getChecked (
179
+ [&]() { return parser.emitError (parser.getNameLoc ()); },
180
+ parser.getContext (), shape, elementType,
181
+ encoding.value_or (mlir::Attribute ()), sg_map.value_or (mlir::Attribute ()));
181
182
}
182
183
183
184
void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
223
224
return Base::get (context, shape, elementType, attr, sg_map);
224
225
}
225
226
227
+ LogicalResult TensorDescType::verify (
228
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
229
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
230
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
231
+ size_t rank = shape.size ();
232
+ if (rank > 2 )
233
+ return emitError () << " desc shape rank exceeds 2" ;
234
+
235
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
236
+ ArrayRef<uint32_t > wiLayout = sgMapAttr.getWiLayout ();
237
+ ArrayRef<uint32_t > wiData = sgMapAttr.getWiData ();
238
+
239
+ if (rank == 1 ) {
240
+ if (wiLayout[0 ] != 1 || wiData[0 ] != 1 )
241
+ return emitError () << " outer layout and data mapping must be 1 "
242
+ " for 1D tensor" ;
243
+ }
244
+
245
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
246
+ // validation logic.
247
+ SmallVector<int64_t > tensorShape (shape.begin (), shape.end ());
248
+ if (rank == 1 )
249
+ tensorShape = {1 , tensorShape.back ()};
250
+
251
+ size_t dims = tensorShape.size ();
252
+ for (size_t i = 0 ; i < dims; ++i) {
253
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
254
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0 )
255
+ return emitError () << " cannot map " << tensorShape[i]
256
+ << " elements into " << wiLayout[i] << " by "
257
+ << wiData[i] << " tiles" ;
258
+ }
259
+
260
+ if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
261
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
262
+ if (wiData[0 ] != 1 )
263
+ return emitError ()
264
+ << " cannot map over non-contiguous scattered elements" ;
265
+
266
+ unsigned chunkSize = scatterAttr.getChunkSize ().getInt ();
267
+ if (wiData[1 ] > chunkSize)
268
+ return emitError ()
269
+ << " too few contiguous elements for work item mapping" ;
270
+ }
271
+ }
272
+
273
+ return success ();
274
+ }
275
+
226
276
} // namespace xegpu
227
277
} // namespace mlir
228
278
0 commit comments