[mlir][xegpu] Improve scatter attribute definition #126540


Merged
5 commits merged into llvm:main on Feb 11, 2025

Conversation

adam-smnk
Contributor

Refactors the XeGPU scatter attribute, introducing the following:

  • improved docs formatting
  • default-initialized parameters
  • invariant checks in the attribute verifier
  • removal of a redundant parsing error

The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling.
The related descriptor verifier is updated to avoid check duplication.
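
For illustration, a minimal sketch of the new round-trip behavior, based on the test case added in XeGPUOps.mlir (operand names are illustrative): a descriptor written with an explicit chunk_size = 1 now prints back in the compact form with the defaults elided.

  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  // Explicitly spelling out the default chunk size...
  %tdesc = xegpu.create_tdesc %src, %offsets : memref<?xf32>, vector<4xindex>
      -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
  // ...round-trips to the defaulted attribute form:
  //   !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>

Because the getters now always return a value (1 for chunk_size, Global for memory_space), callers such as TensorDescType::getChunkSize() in the diff below no longer need a null check before calling getInt().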

@llvmbot
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Adam Siemieniuk (adam-smnk)

Changes

Refactors the XeGPU scatter attribute, introducing the following:

  • improved docs formatting
  • default-initialized parameters
  • invariant checks in the attribute verifier
  • removal of a redundant parsing error

The attribute's getters now provide default values, simplifying their usage and scattered tensor descriptor handling.
The related descriptor verifier is updated to avoid check duplication.


Full diff: https://github.com/llvm/llvm-project/pull/126540.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+18-6)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+14-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+1-9)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+9)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+10-1)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 4841f94de75f4aa..0136b18ccfa9461 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -59,19 +59,29 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
 
 def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
   let summary = [{a composite attribute for `TensorDescType`}];
-  let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite
-    attribute defined for `TensorDescType` for describing following
-    properties of a `TensorDesc`.
+  let description = [{
+    `ScatterTensorDesc` is a composite attribute defined for `TensorDescType`
+    for describing following properties of a `TensorDesc`:
+
     1. `memory_space`: It describes where the data block described by the
         TensorDesc is located, `Global` device memory or `Shared` local memory.
         It is default to `Global`.
-    2.  `chunk_size`: indicates number of continious elements accessed for each
+
+    2.  `chunk_size`: indicates number of contiguous elements accessed for each
         offset, default is 1. It is used with `scattered` attr only.
   }];
 
   let parameters = (ins
-    OptionalParameter<"MemorySpaceAttr">: $memory_space,
-    OptionalParameter<"IntegerAttr", "1">: $chunk_size
+    DefaultValuedParameter<
+      "MemorySpaceAttr",
+      "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
+      "Data memory location"
+    >: $memory_space,
+    DefaultValuedParameter<
+      "IntegerAttr",
+      "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
+      "Number of contiguous elements"
+    >: $chunk_size
   );
 
   let builders = [
@@ -80,6 +90,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
       CArg<"int", "1">: $chunk_size
     )>
   ];
+
+  let genVerifyDecl = 1;
  }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 494f11f041b71ff..cc2e93fb19a7048 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -172,7 +172,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       auto attr = getEncoding();
       auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
       assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
-      if (scatter_attr && scatter_attr.getChunkSize())
+      if (scatter_attr)
         return scatter_attr.getChunkSize().getInt();
       return 1;
     }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index becc32d1226973d..06fd03f3af3ad5a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -55,6 +55,18 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
+LogicalResult ScatterTensorDescAttr::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
+  int64_t chunkSize = chunk_size.getInt();
+  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
+                                              16, 32, 64, 128, 256};
+  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+    return emitError() << "invalid chunk size";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SGMapAttr
 //===----------------------------------------------------------------------===//
@@ -166,8 +178,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
         continue;
       }
     }
-    parser.emitError(parser.getCurrentLocation(),
-                     "Failed to parse the attribute.\n");
     return {};
   }
 
@@ -237,8 +247,7 @@ LogicalResult TensorDescType::verify(
     // Expected tensor ranks for scattered data:
     //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
     //   - 2D tensor for scattered blocks (chunk size > 1)
-    IntegerAttr chunkAttr = scatterAttr.getChunkSize();
-    unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
     if (rank == 1 && chunkSize != 1)
       return emitError() << "expected non-contiguous elements for 1D tensor";
     if (rank == 2 && chunkSize < 2)
@@ -273,8 +282,7 @@ LogicalResult TensorDescType::verify(
         return emitError()
                << "cannot map over non-contiguous scattered row elements";
 
-      IntegerAttr chunkAttr = scatterAttr.getChunkSize();
-      unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
       if (wiData[1] != chunkSize)
         return emitError() << "work item data mapping must match the number of "
                               "contiguous elements";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e06d99ac20bb736..25dc1f22f043264 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -419,16 +419,8 @@ LogicalResult CreateDescOp::verify() {
            << " Source: " << srcMemorySpace
            << ", TensorDesc: " << tdescMemorySpace;
 
-  auto chunkSize = tdescTy.getChunkSize();
-
-  // check chunk_size
-  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
-                                                    16, 32, 64, 128, 256};
-  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
-    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
-                       "8, 16, 32, 64, 128, or 256.");
-
   // check total size
+  auto chunkSize = tdescTy.getChunkSize();
   auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
   auto bitsPerLane = elemBits * chunkSize;
   if (chunkSize > 1 && bitsPerLane % 32) {
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 8af1b600ad0a4e2..472176af72b191f 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -181,6 +181,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref<?xf32>) {
+gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>  -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 9162e0012f6d56d..86356e09de57cef 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -190,7 +190,7 @@ func.func @test_create_tdesc_vc_2(%src: ui64) {
 }
 
 // -----
-func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
+func.func @test_create_tdesc_vc_3(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   // expected-error@+1 {{Memory space mismatch}}
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
@@ -198,6 +198,15 @@ func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
   return
 }
 
+// -----
+func.func @test_create_tdesc_vc_4(%src: memref<?xf32>) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
+  // expected-error@+1 {{invalid chunk size}}
+          -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
+  return
+}
+
 // -----
 func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>

Contributor

@chencha3 left a comment


LGTM. Thanks.

@adam-smnk merged commit 67f59a6 into llvm:main on Feb 11, 2025
11 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025