@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
175
175
if (parser.parseGreater ())
176
176
return {};
177
177
178
- return TensorDescType::get (parser.getContext (), shape, elementType,
179
- encoding.value_or (mlir::Attribute ()),
180
- sg_map.value_or (mlir::Attribute ()));
178
+ return TensorDescType::getChecked (
179
+ [&]() { return parser.emitError (parser.getNameLoc ()); },
180
+ parser.getContext (), shape, elementType,
181
+ encoding.value_or (mlir::Attribute ()), sg_map.value_or (mlir::Attribute ()));
181
182
}
182
183
183
184
void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,81 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
223
224
return Base::get (context, shape, elementType, attr, sg_map);
224
225
}
225
226
227
+ LogicalResult TensorDescType::verify (
228
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
229
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
230
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
231
+ size_t rank = shape.size ();
232
+ if (rank != 1 && rank != 2 )
233
+ return emitError () << " expected 1D or 2D tensor" ;
234
+
235
+ auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
236
+ if (scatterAttr) {
237
+ // Expected tensor ranks for scattered data:
238
+ // - 1D tensor for fully non-contiguous elements (chunk size == 1)
239
+ // - 2D tensor for scattered blocks (chunk size > 1)
240
+ IntegerAttr chunkAttr = scatterAttr.getChunkSize ();
241
+ unsigned chunkSize = chunkAttr ? chunkAttr.getInt () : 1 ;
242
+ if (rank == 1 && chunkSize != 1 )
243
+ return emitError () << " expected non-contiguous elements for 1D tensor" ;
244
+ if (rank == 2 && chunkSize < 2 )
245
+ return emitError () << " expected chunk blocks for 2D tensor" ;
246
+ }
247
+
248
+ if (auto blockAttr =
249
+ mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
250
+ MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace ();
251
+ if (rank == 2 && memorySpaceAttr &&
252
+ memorySpaceAttr.getValue () == MemorySpace::SLM)
253
+ return emitError () << " SLM is not supported for 2D block tensor" ;
254
+ }
255
+
256
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
257
+ ArrayRef<uint32_t > wiLayout = sgMapAttr.getWiLayout ();
258
+ ArrayRef<uint32_t > wiData = sgMapAttr.getWiData ();
259
+
260
+ if (rank == 1 ) {
261
+ if (wiLayout[0 ] != 1 || wiData[0 ] != 1 )
262
+ return emitError ()
263
+ << " outer layout distribution and data mapping must be 1 "
264
+ " for 1D tensor" ;
265
+ }
266
+
267
+ if (scatterAttr) {
268
+ // Validate subgroup mapping rules for scattered tensors.
269
+ // A work-item's slice of the tensor with shape [sg_size] or
270
+ // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
271
+ // the mapping should reflect that.
272
+ if (wiData[0 ] != 1 )
273
+ return emitError ()
274
+ << " cannot map over non-contiguous scattered row elements" ;
275
+
276
+ IntegerAttr chunkAttr = scatterAttr.getChunkSize ();
277
+ unsigned chunkSize = chunkAttr ? chunkAttr.getInt () : 1 ;
278
+ if (wiData[1 ] != chunkSize)
279
+ return emitError () << " work item data mapping must match the number of "
280
+ " contiguous elements" ;
281
+ }
282
+
283
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
284
+ // validation logic.
285
+ SmallVector<int64_t > tensorShape (shape.begin (), shape.end ());
286
+ if (rank == 1 )
287
+ tensorShape = {1 , tensorShape.back ()};
288
+
289
+ size_t dims = tensorShape.size ();
290
+ for (size_t i = 0 ; i < dims; ++i) {
291
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
292
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0 )
293
+ return emitError () << " cannot distribute " << tensorShape[i] << " over "
294
+ << wiLayout[i] << " work items with " << wiData[i]
295
+ << " elements each" ;
296
+ }
297
+ }
298
+
299
+ return success ();
300
+ }
301
+
226
302
} // namespace xegpu
227
303
} // namespace mlir
228
304
0 commit comments