Skip to content

Commit a72c12f

Browse files
committed
Move memory space check to type verifier
1 parent 50c6283 commit a72c12f

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,6 @@ LogicalResult TensorDescType::verify(
232232
if (rank != 1 && rank != 2)
233233
return emitError() << "expected 1D or 2D tensor";
234234

235-
// Scattered attribute imposes extra restriction on tensor descriptor.
236-
// Block attribute can only be validated further against data transfer
237-
// operations.
238235
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
239236
if (scatterAttr) {
240237
// Expected tensor ranks for scattered data:
@@ -248,6 +245,14 @@ LogicalResult TensorDescType::verify(
248245
return emitError() << "expected chunk blocks for 2D tensor";
249246
}
250247

248+
if (auto blockAttr =
249+
mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
250+
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
251+
if (rank == 2 && memorySpaceAttr &&
252+
memorySpaceAttr.getValue() == MemorySpace::SLM)
253+
return emitError() << "SLM is not supported for 2D block tensor";
254+
}
255+
251256
if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
252257
ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
253258
ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
231231
if (getType().isScattered())
232232
return emitOpError("Expects a non-scattered TensorDesc.\n");
233233

234-
if (getType().getRank() == 2 &&
235-
tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
236-
return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
237-
238234
return success();
239235
}
240236

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
1717

1818
// -----
1919
func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
20-
// expected-error@+1 {{SLM is not supported for 2D Block TensorDesc}}
20+
// expected-error@+1 {{SLM is not supported for 2D block tensor}}
2121
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
2222
return
2323
}

0 commit comments

Comments
 (0)