-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][xegpu] Improve scatter attribute definition #126540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Adam Siemieniuk (adam-smnk) Changes: Refactors the XeGPU scatter attribute, introducing the following:
The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling. Full diff: https://github.com/llvm/llvm-project/pull/126540.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 4841f94de75f4aa..0136b18ccfa9461 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -59,19 +59,29 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
let summary = [{a composite attribute for `TensorDescType`}];
- let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite
- attribute defined for `TensorDescType` for describing following
- properties of a `TensorDesc`.
+ let description = [{
+ `ScatterTensorDesc` is a composite attribute defined for `TensorDescType`
+ for describing following properties of a `TensorDesc`:
+
1. `memory_space`: It describes where the data block described by the
TensorDesc is located, `Global` device memory or `Shared` local memory.
It is default to `Global`.
- 2. `chunk_size`: indicates number of continious elements accessed for each
+
+ 2. `chunk_size`: indicates number of contiguous elements accessed for each
offset, default is 1. It is used with `scattered` attr only.
}];
let parameters = (ins
- OptionalParameter<"MemorySpaceAttr">: $memory_space,
- OptionalParameter<"IntegerAttr", "1">: $chunk_size
+ DefaultValuedParameter<
+ "MemorySpaceAttr",
+ "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
+ "Data memory location"
+ >: $memory_space,
+ DefaultValuedParameter<
+ "IntegerAttr",
+ "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
+ "Number of contiguous elements"
+ >: $chunk_size
);
let builders = [
@@ -80,6 +90,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
CArg<"int", "1">: $chunk_size
)>
];
+
+ let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 494f11f041b71ff..cc2e93fb19a7048 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -172,7 +172,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
auto attr = getEncoding();
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
- if (scatter_attr && scatter_attr.getChunkSize())
+ if (scatter_attr)
return scatter_attr.getChunkSize().getInt();
return 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index becc32d1226973d..06fd03f3af3ad5a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -55,6 +55,18 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, chunkSizeAttr);
}
+LogicalResult ScatterTensorDescAttr::verify(
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
+ int64_t chunkSize = chunk_size.getInt();
+ SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
+ 16, 32, 64, 128, 256};
+ if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+ return emitError() << "invalid chunk size";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SGMapAttr
//===----------------------------------------------------------------------===//
@@ -166,8 +178,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
continue;
}
}
- parser.emitError(parser.getCurrentLocation(),
- "Failed to parse the attribute.\n");
return {};
}
@@ -237,8 +247,7 @@ LogicalResult TensorDescType::verify(
// Expected tensor ranks for scattered data:
// - 1D tensor for fully non-contiguous elements (chunk size == 1)
// - 2D tensor for scattered blocks (chunk size > 1)
- IntegerAttr chunkAttr = scatterAttr.getChunkSize();
- unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
if (rank == 1 && chunkSize != 1)
return emitError() << "expected non-contiguous elements for 1D tensor";
if (rank == 2 && chunkSize < 2)
@@ -273,8 +282,7 @@ LogicalResult TensorDescType::verify(
return emitError()
<< "cannot map over non-contiguous scattered row elements";
- IntegerAttr chunkAttr = scatterAttr.getChunkSize();
- unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
if (wiData[1] != chunkSize)
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e06d99ac20bb736..25dc1f22f043264 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -419,16 +419,8 @@ LogicalResult CreateDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
- auto chunkSize = tdescTy.getChunkSize();
-
- // check chunk_size
- llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
- 16, 32, 64, 128, 256};
- if (!llvm::is_contained(supportedChunkSizes, chunkSize))
- return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
- "8, 16, 32, 64, 128, or 256.");
-
// check total size
+ auto chunkSize = tdescTy.getChunkSize();
auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
auto bitsPerLane = elemBits * chunkSize;
if (chunkSize > 1 && bitsPerLane % 32) {
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 8af1b600ad0a4e2..472176af72b191f 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -181,6 +181,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
gpu.return
}
+// CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref<?xf32>) {
+gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 9162e0012f6d56d..86356e09de57cef 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -190,7 +190,7 @@ func.func @test_create_tdesc_vc_2(%src: ui64) {
}
// -----
-func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
+func.func @test_create_tdesc_vc_3(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// expected-error@+1 {{Memory space mismatch}}
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
@@ -198,6 +198,15 @@ func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
return
}
+// -----
+func.func @test_create_tdesc_vc_4(%src: memref<?xf32>) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
+ // expected-error@+1 {{invalid chunk size}}
+ -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
+ return
+}
+
// -----
func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
Refactors the XeGPU scatter attribute, introducing the following: - improved docs formatting - default-initialized parameters - invariant checks in the attribute verifier - removal of an additional parsing error The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling. The related descriptor verifier is updated to avoid check duplication.
Refactors the XeGPU scatter attribute, introducing the following: - improved docs formatting - default-initialized parameters - invariant checks in the attribute verifier - removal of an additional parsing error The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling. The related descriptor verifier is updated to avoid check duplication.
Refactors the XeGPU scatter attribute, introducing the following: - improved docs formatting - default-initialized parameters - invariant checks in the attribute verifier - removal of an additional parsing error The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling. The related descriptor verifier is updated to avoid check duplication.
Refactors the XeGPU scatter attribute, introducing the following:
The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling.
Related descriptor verifier is updated to avoid check duplication.