Skip to content

Commit 5df9cf0

Browse files
committed
Improve scattered verification
1 parent fa82d89 commit 5df9cf0

File tree

2 files changed

+58
-32
lines changed

2 files changed

+58
-32
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -229,17 +229,47 @@ LogicalResult TensorDescType::verify(
229229
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
230230
mlir::Attribute encoding, mlir::Attribute sg_map) {
231231
size_t rank = shape.size();
232-
if (rank > 2)
233-
return emitError() << "desc shape rank exceeds 2";
232+
if (rank != 1 && rank != 2)
233+
return emitError() << "expected 1D or 2D tensor";
234+
235+
// Scattered attribute imposes extra restriction on tensor descriptor.
236+
// Block attribute can only be validated further against data transfer
237+
// operations.
238+
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
239+
if (scatterAttr) {
240+
// Expected tensor ranks for scattered data:
241+
// - 1D tensor for fully non-contiguous elements (chunk size == 1)
242+
// - 2D tensor for scattered blocks (chunk size > 1)
243+
IntegerAttr chunkAttr = scatterAttr.getChunkSize();
244+
unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
245+
if (rank == 1 && chunkSize != 1)
246+
return emitError() << "expected non-contiguous elements for 1D tensor";
247+
if (rank == 2 && chunkSize < 2)
248+
return emitError() << "expected chunk blocks for 2D tensor";
249+
}
234250

235251
if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
236252
ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
237253
ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
238254

239255
if (rank == 1) {
240256
if (wiLayout[0] != 1 || wiData[0] != 1)
241-
return emitError() << "outer layout and data mapping must be 1 "
242-
"for 1D tensor";
257+
return emitError()
258+
<< "outer layout distribution and data mapping must be 1 "
259+
"for 1D tensor";
260+
}
261+
262+
if (scatterAttr) {
263+
// Validate subgroup mapping rules for scattered tensors.
264+
if (wiData[0] != 1)
265+
return emitError()
266+
<< "cannot map over non-contiguous scattered row elements";
267+
268+
IntegerAttr chunkAttr = scatterAttr.getChunkSize();
269+
unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
270+
if (wiData[1] != chunkSize)
271+
return emitError() << "work item data mapping must match the number of "
272+
"contiguous elements";
243273
}
244274

245275
// For 1D tensor, pad the shape with an outer unit dimension to allow common
@@ -252,21 +282,9 @@ LogicalResult TensorDescType::verify(
252282
for (size_t i = 0; i < dims; ++i) {
253283
uint32_t numElemPerWi = wiLayout[i] * wiData[i];
254284
if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
255-
return emitError() << "cannot map " << tensorShape[i]
256-
<< " elements into " << wiLayout[i] << " by "
257-
<< wiData[i] << " tiles";
258-
}
259-
260-
if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
261-
auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
262-
if (wiData[0] != 1)
263-
return emitError()
264-
<< "cannot map over non-contiguous scattered elements";
265-
266-
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
267-
if (wiData[1] > chunkSize)
268-
return emitError()
269-
<< "too few contiguous elements for work item mapping";
285+
return emitError() << "cannot distribute " << tensorShape[i] << " over "
286+
<< wiLayout[i] << " work items with " << wiData[i]
287+
<< " elements each";
270288
}
271289
}
272290

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
183183
// -----
184184
func.func @test_create_tdesc_vc_2(%src: ui64) {
185185
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
186-
// expected-error@+1 {{Incorrect TensorDesc shape}}
187186
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
187+
// expected-error@+1 {{expected chunk blocks for 2D tensor}}
188188
-> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
189189
return
190190
}
@@ -219,23 +219,23 @@ func.func @test_prefetch_vc_2(%src: ui64) {
219219
// -----
220220
func.func @test_create_tdesc_sg_map_1(%src: ui64) {
221221
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
222-
// expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
222+
// expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
223223
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
224224
return
225225
}
226226

227227
// -----
228228
func.func @test_create_tdesc_sg_map_2(%src: ui64) {
229229
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
230-
// expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
230+
// expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
231231
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
232232
return
233233
}
234234

235235
// -----
236236
func.func @test_create_tdesc_sg_map_3(%src: ui64) {
237237
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
238-
// expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
238+
// expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
239239
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
240240
return
241241
}
@@ -366,55 +366,63 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
366366
// -----
367367
func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
368368
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
369-
// expected-error@+1 {{desc shape rank exceeds 2}}
369+
// expected-error@+1 {{expected 1D or 2D tensor}}
370370
!xegpu.tensor_desc<16x2x2xf32>
371371
return
372372
}
373373

374+
// -----
375+
func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
376+
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
377+
// expected-error@+1 {{expected 1D or 2D tensor}}
378+
!xegpu.tensor_desc<f32>
379+
return
380+
}
381+
374382
// -----
375383
func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
376384
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
377-
// expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
385+
// expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
378386
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
379387
return
380388
}
381389

382390
// -----
383391
func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
384392
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
385-
// expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
393+
// expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
386394
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
387395
return
388396
}
389397

390398
// -----
391399
func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
392400
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
393-
// expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
401+
// expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}}
394402
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
395403
return
396404
}
397405

398406
// -----
399407
func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
400408
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
401-
// expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
409+
// expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
402410
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
403411
return
404412
}
405413

406414
// -----
407415
func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
408416
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
409-
// expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
417+
// expected-error@+1 {{cannot distribute 4 over 2 work items with 4 elements each}}
410418
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
411419
return
412420
}
413421

414422
// -----
415423
func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
416424
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
417-
// expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
425+
// expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
418426
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
419427
return
420428
}
@@ -423,7 +431,7 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
423431
func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
424432
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
425433
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
426-
// expected-error@+1 {{cannot map over non-contiguous scattered elements}}
434+
// expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
427435
!xegpu.tensor_desc<4x2xf32,
428436
#xegpu.scatter_tdesc_attr<chunk_size = 2>,
429437
#xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
@@ -433,7 +441,7 @@ func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
433441
// -----
434442
func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
435443
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
436-
// expected-error@+1 {{too few contiguous elements for work item mapping}}
444+
// expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
437445
!xegpu.tensor_desc<16xf32,
438446
#xegpu.scatter_tdesc_attr<chunk_size = 1>,
439447
#xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>

0 commit comments

Comments
 (0)