[MLIR][XeGPU] Add sg_map attribute to support Work Item level semanti… #110876

Merged: 1 commit into llvm:main, Oct 2, 2024

Conversation

kurapov-peter (Contributor)

…cs (llvm#108864)

The PR adds an attribute (sg_map) describing the distribution of
computation among work items for XeGPU operations, to be used in lowering
passes. The map is attached to the tensor descriptor, so the constructor
and the type are updated. Tests check the custom parser and printer. The
attribute is optional for now, so no other changes are required.
The complete description of the attribute can be found
[here](https://github.com/intel/mlir-extensions/blob/main/docs/rfcs/XeGPU.md#xegpu-attributes-to-support-work-item-level-semantics).
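
As a rough illustration of the semantics (a hypothetical sketch based on the attribute description and on the tests added in this PR, not additional patch content): with `wi_layout = [1, 16]` and `wi_data = [1, 1]`, the 1x16 work-item tile is replicated over an 8x16 tensor, so each of the 16 work items covers one 8x1 column, one element per distribution step.

```mlir
// Hypothetical example. wi_layout = [1, 16]: the subgroup's 16 work items are
// laid out along the second (16-wide) dimension. wi_data = [1, 1]: each work
// item gets one element per distribution step. The 1x16 layout tile repeats
// 8 times over the rows, so each work item covers one 8x1 column of the tile.
gpu.func @sg_map_example(%src: memref<24x32xf32>) {
  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  gpu.return
}
```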
@llvmbot (Member) commented Oct 2, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Petr Kurapov (kurapov-peter)

Changes

…cs (#108864)

The PR adds an attribute (sg_map) describing the distribution of computation among work items for XeGPU operations, to be used in lowering passes. The map is attached to the tensor descriptor, so the constructor and the type are updated. Tests check the custom parser and printer. The attribute is optional for now, so no other changes are required. The complete description of the attribute can be found here.


Full diff: https://github.com/llvm/llvm-project/pull/110876.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+32)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+15-5)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+98-12)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 26eec0d4f2082a..2aaa7fd4221ab1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -142,4 +142,36 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
+def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
+  let summary = [{
+    Describes the mapping between work item (WI) and the 2D tensor specified by the tensor descriptor.
+  }];
+  let description = [{
+    To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
+    attribute at the tensor description creation time.
+    Within the `sg_map`, `wi_layout` specifies the layout of work items,
+    describing the mapping of work items to the tensor.
+    wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
+    `wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
+
+    E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+    In this example, the subgroup has 16 work items in wi_layout=[1, 16],
+    each accessing 1 element as specified by wi_data=[1, 1].
+
+    `wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
+    which is eventually lowered to "SIMT-flavor" vector, like SPIR-V vector or llvm vector, or packed to a storage data type.
+    The multiple elements indicated by `wi_data` can only be from one dimension and must be contiguous in the memory along either dimension.
+  }];
+  let parameters = (ins
+    ArrayRefParameter<"uint32_t">:$wi_layout,
+    ArrayRefParameter<"uint32_t">:$wi_data);
+
+  let builders = [
+    AttrBuilder<(ins)>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 0ce1211664b5ba..d09c5c1870d506 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     element-type ::= float-type | integer-type | index-type
     dim-list := (static-dim-list `x`)?
     static-dim-list ::= decimal-literal `x` decimal-literal
-    attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
+    attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
     ```
 
     Examples:
@@ -77,12 +77,16 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
 
     // A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
     xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
+
+    // A TensorDesc with a sg_map
+    xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
     ```
   }];
 
   let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
                         "mlir::Type": $elementType,
-                        OptionalParameter<"mlir::Attribute">: $encoding);
+                        OptionalParameter<"mlir::Attribute">: $encoding,
+                        OptionalParameter<"mlir::Attribute">: $sg_map);
 
   let builders = [
     TypeBuilderWithInferredContext<(ins
@@ -90,14 +94,16 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       "mlir::Type": $elementType,
       CArg<"int", "1">: $array_length,
       CArg<"bool", "true">: $boundary_check,
-      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>": $shape,
       "mlir::Type": $elementType,
       CArg<"int", "1">: $chunk_size,
-      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
   ];
-
+  
   let extraClassDeclaration = [{
     using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -121,6 +127,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
     }
 
+    SGMapAttr getSGMapAttr() const {
+      return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+    }
+
     xegpu::MemorySpace getMemorySpace() const {
       auto block_attr = getEncodingAsBlockTensorDescAttr();
       if (block_attr && block_attr.getMemorySpace())
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1dfbaed454c193..eb01b15de75c60 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_SGMapAttr
+//===----------------------------------------------------------------------===//
+namespace {
+template <typename T, unsigned N>
+LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
+                                 llvm::SmallVector<T, N> &result,
+                                 llvm::StringRef fieldName) {
+  if (failed(parser.parseKeyword(fieldName))) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "unexpected field name. Expected " + fieldName + ".");
+    return failure();
+  }
+
+  if (failed(parser.parseEqual())) {
+    parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
+    return failure();
+  }
+
+  auto elemParser = [&]() -> llvm::ParseResult {
+    uint32_t elem = 0;
+    auto res = parser.parseInteger(elem);
+    result.push_back(elem);
+    return res;
+  };
+
+  return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                        elemParser, fieldName);
+}
+} // namespace
+
+mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
+                                 ::mlir::Type attrType) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
+  if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
+    return {};
+
+  if (failed(parser.parseComma()))
+    return {};
+
+  if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
+    return {};
+
+  return SGMapAttr::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), wi_layout, wi_data);
+}
+
+void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
+  printer << "<";
+  printer.printKeywordOrString("wi_layout");
+  printer << " = [" << getWiLayout() << "], ";
+  printer.printKeywordOrString("wi_data");
+  printer << " = [" << getWiData() << "]";
+  printer << ">";
+}
+
+LogicalResult
+SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                  llvm::ArrayRef<uint32_t> wi_layout,
+                  llvm::ArrayRef<uint32_t> wi_data) {
+  if (wi_layout.size() != 2)
+    return emitError() << "expected wi_layout of size 2";
+  if (wi_data.size() != 2)
+    return emitError() << "expected wi_data of size 2";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   llvm::SmallVector<int64_t> shape;
   mlir::Type elementType;
   mlir::FailureOr<mlir::Attribute> encoding;
+  mlir::FailureOr<mlir::Attribute> sg_map;
 
   // Parse literal '<'
   if (parser.parseLess())
@@ -81,14 +153,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   }
 
   // parse optional attributes
-  if (mlir::succeeded(parser.parseOptionalComma())) {
-    encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
-    if (mlir::failed(encoding)) {
-      parser.emitError(
-          parser.getCurrentLocation(),
-          "Failed to parse the attribute field for TensorDescType.\n");
-      return {};
+  while (mlir::succeeded(parser.parseOptionalComma())) {
+    mlir::Attribute attr;
+    ParseResult res = parser.parseAttribute(attr);
+    if (mlir::succeeded(res)) {
+      if (mlir::isa<SGMapAttr>(attr)) {
+        sg_map = attr;
+        continue;
+      }
+      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
+        encoding = attr;
+        continue;
+      }
     }
+    parser.emitError(parser.getCurrentLocation(),
+                     "Failed to parse the attribute.\n");
+    return {};
   }
 
   // Parse literal '>'
@@ -96,7 +176,8 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
     return {};
 
   return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()));
+                             encoding.value_or(mlir::Attribute()),
+                             sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
   if (auto encoding = getEncoding())
     printer << ", " << encoding;
 
+  if (auto sg_map = getSgMap())
+    printer << ", " << sg_map;
+
   printer << ">";
 }
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType, int array_length,
                                    bool boundary_check,
-                                   MemorySpace memory_space) {
+                                   MemorySpace memory_space,
+                                   mlir::Attribute sg_map) {
   auto context = elementType.getContext();
   auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                        boundary_check);
-  return Base::get(context, shape, elementType, attr);
+  return Base::get(context, shape, elementType, attr, sg_map);
 }
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType, int chunk_size,
-                                   MemorySpace memory_space) {
+                                   MemorySpace memory_space,
+                                   mlir::Attribute sg_map) {
   auto context = elementType.getContext();
   auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
-  return Base::get(context, shape, elementType, attr);
+  return Base::get(context, shape, elementType, attr, sg_map);
 }
 
 } // namespace xegpu
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 6db57aad773aa8..a4587faa3345cb 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -43,6 +51,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -120,6 +135,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
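
Two observable consequences of the parser and verifier changes in the diff above, sketched as hypothetical examples (they are not part of the patch or its tests):

```mlir
// 1) TensorDescType::parse now dispatches on the parsed attribute's kind
//    rather than its position, so the optional attributes should round-trip
//    in either order. Both spellings below denote the same type; the printer
//    always emits the encoding before the sg_map.
!xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
                   #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
!xegpu.tensor_desc<24x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
                   #xegpu.block_tdesc_attr<array_length = 2 : i64>>

// 2) SGMapAttr::verify requires both arrays to have exactly two elements,
//    so a rank-1 map such as the following would be rejected:
//    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [16], wi_data = [1]>>
```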

@llvmbot (Member) commented Oct 2, 2024

@llvm/pr-subscribers-mlir

Author: Petr Kurapov (kurapov-peter)

(Same changes summary and full diff as in the @llvm/pr-subscribers-mlir-gpu comment above.)
@chencha3 (Contributor) left a comment

LGTM

AttrBuilder<(ins)>
];

let hasCustomAssemblyFormat = 1;
Contributor

Nit: does this have to use custom assembly format?

Contributor Author (kurapov-peter)

No, I added it to minimize changes to other parts. I can remove it later on.

chencha3 merged commit 9fa55ec into llvm:main on Oct 2, 2024. 11 checks passed.
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request on Oct 3, 2024.
kurapov-peter deleted the xegpu-sgmap branch on October 3, 2024 at 12:27.