
Commit 8a03658

[mlir][xegpu] Tensor descriptor type verifier (#124548)
Adds XeGPU tensor descriptor type verifier. The type verifier covers general tensor descriptor invariants w.r.t. Xe ISA semantics. Related operation verifiers are updated to account for the new descriptor checks and avoid duplication.
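
Concretely, the new type-level invariants are: the shape must be 1D or 2D; a scattered descriptor must be 1D when its chunk size is 1 and 2D when the chunk size is greater than 1; a 2D block descriptor may not use SLM; and an attached sg_map must distribute each dimension evenly across work items. A hypothetical sketch of descriptor types that satisfy these rules (attribute syntax assumed from the dialect's tests, not taken from this patch):

// 2D block descriptor distributed across 16 work items, one element each.
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>

// 1D scattered descriptor: fully non-contiguous elements (chunk size == 1).
!xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>

// 2D scattered blocks: one row per offset, chunk_size contiguous elements per row.
!xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>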

5 files changed: 278 additions, 44 deletions

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   }];
 
   let hasCustomAssemblyFormat = true;
-
+  let genVerifyDecl = 1;
 }

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 79 additions & 3 deletions
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,81 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank != 1 && rank != 2)
+    return emitError() << "expected 1D or 2D tensor";
+
+  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+  if (scatterAttr) {
+    // Expected tensor ranks for scattered data:
+    // - 1D tensor for fully non-contiguous elements (chunk size == 1)
+    // - 2D tensor for scattered blocks (chunk size > 1)
+    IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+    unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+    if (rank == 2 && chunkSize < 2)
+      return emitError() << "expected chunk blocks for 2D tensor";
+  }
+
+  if (auto blockAttr =
+          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError()
+               << "outer layout distribution and data mapping must be 1 "
+                  "for 1D tensor";
+    }
+
+    if (scatterAttr) {
+      // Validate subgroup mapping rules for scattered tensors.
+      // A work-item's slice of the tensor with shape [sg_size] or
+      // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively;
+      // the mapping should reflect that.
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered row elements";
+
+      IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+      unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+      if (wiData[1] != chunkSize)
+        return emitError() << "work item data mapping must match the number of "
+                              "contiguous elements";
+    }
+
+    // For a 1D tensor, pad the shape with an outer unit dimension to allow
+    // common validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot distribute " << tensorShape[i] << " over "
+                           << wiLayout[i] << " work items with " << wiData[i]
+                           << " elements each";
+    }
+  }
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
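
Because the parser now builds the type through getChecked, a malformed descriptor is diagnosed as soon as it is parsed or created, independent of which operation carries it. Hypothetical examples of types the verifier rejects, paired with the diagnostics above (attribute syntax assumed from the dialect's tests; the patch's own negative tests are in a changed file not shown in this excerpt):

// "expected 1D or 2D tensor": rank-3 descriptors are invalid.
!xegpu.tensor_desc<2x8x16xf32>

// "SLM is not supported for 2D block tensor".
!xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>

// "expected chunk blocks for 2D tensor": 2D scattered data needs chunk_size > 1.
!xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>

// "cannot distribute 32 over 16 work items with 4 elements each":
// 16 work items * 4 elements = 64 cannot be carved out of 32.
!xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 4]>>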

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 16 additions & 31 deletions
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 // each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;
 
+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;
 
-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;
 
+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
@@ -227,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
   if (getType().isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
-  if (getType().getRank() == 2 &&
-      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
-    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
-
   return success();
 }
 
@@ -454,22 +454,7 @@ LogicalResult CreateDescOp::verify() {
   if (shape != tdescShape)
     return emitOpError("Incorrect TensorDesc shape. ")
            << "Expected is " << makeString(shape) << "\n";
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    // A work-item's slice of the TensorDesc with shape [sg_size] or
-    // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-    // the mapping should reflect that.
-    if (sgMap.getWiData()[0] > 1)
-      return emitOpError("TensorDesc's SG map only supports multiple elements "
-                         "contiguous along rows.");
-    if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
-      return emitOpError(
-          "TensorDesc's chunkSize must match WI's data mapping.");
-    if (int rank = tdescTy.getRank();
-        (sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
-      return emitOpError("Detected a conflict between SG map's work-item "
-                         "layout and TensorDesc shape. Check the index of "
-                         "`subgroup_size` in WI layout map.");
-  }
+
   return success();
 }
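
The relaxed isArgShapesValid accepts distributed 1D shapes by validating only the innermost layout dimension: for a 32-element descriptor with wi_layout = [1, 16], the factor * dim == expected check admits a per-work-item vector<2xf32> since 16 * 2 == 32, which the new round-trip tests below exercise. A hypothetical counterexample that the check still rejects (not one of the patch's tests):

// Rejected: 16 work items * 4 elements != 32.
%bad = xegpu.load_nd %tdesc <{l1_hint = #xegpu.cache_hint<cached>}>
    : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<4xf32>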

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 22 additions & 0 deletions
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
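
The SG-map constraints deleted from CreateDescOp::verify above are now enforced when the descriptor type itself is built, so a scattered descriptor must already carry a consistent mapping. A hypothetical sketch of a type that satisfies them (scatter_tdesc_attr syntax assumed from the dialect's tests):

// wi_data[0] == 1, wi_data[1] == chunk_size, and shape [16, 2] distributes
// evenly over wi_layout = [16, 1] with wi_data = [1, 2].
!xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
    #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 2]>>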
