Skip to content

[mlir][xegpu] add support for structure control flow ops in workgroup to subgroup distribution #142618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 13, 2025

Conversation

chencha3
Copy link
Contributor

@chencha3 chencha3 commented Jun 3, 2025

This PR introduces support for scf::ForOp, scf::WhileOp, scf::If, and scf::Condition within the workgroup-subgroup-distribution pass. Due to the current TypeConverter's lack of context-aware conversion for VectorType, it leverages xegpu::doSCFStructuralTypeConversionWithTensorType to convert VectorType operands, arguments, and results of these operations. The conversion process involves initially transforming them into RankedTensorType values, allowing the layout attribute to be preserved within the encoding attribute. Subsequently, the RankedTensorType is converted back to VectorType, guided by the layout attribute encoded in the encoding attribute.

Copy link

github-actions bot commented Jun 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@chencha3 chencha3 requested review from nbpatel and charithaintc June 4, 2025 17:34
@chencha3 chencha3 marked this pull request as ready for review June 4, 2025 17:34
@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-libc
@llvm/pr-subscribers-debuginfo
@llvm/pr-subscribers-mc
@llvm/pr-subscribers-lld-elf
@llvm/pr-subscribers-lld-macho
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes

Patch is 22.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142618.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+3)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+136-24)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+3-3)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+25-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+67-3)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f9327d63869c0..6fea10185402a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,9 @@ class TensorDescType;
 
 namespace xegpu {
 
+/// Flatten a set of ValueRange into a single SmallVector<Value>
+SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
+
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
 /// In this mode, the distributed vector shape is determined as follows:
 /// Definitions:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..e29da76898c58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -29,6 +30,29 @@ using namespace mlir;
 
 namespace {
 
+static std::pair<SmallVector<int64_t>, int>
+getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  int count = 1;
+  SmallVector<int64_t> sgShape(shape);
+
+  if (layout && layout.isWgLayout()) {
+    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    else
+      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+    // Clamp distUnit to the original shape to handle cases where data is
+    // shared among subgroups, which may cause distUnit to exceed the original
+    // shape.
+    for (size_t i = 0; i < distUnit.size(); ++i)
+      distUnit[i] = std::min(shape[i], distUnit[i]);
+    count = computeProduct(shape) / computeProduct(distUnit);
+  }
+  return std::make_pair(sgShape, count);
+};
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -129,18 +153,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
 
-    SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
-      assert(wgShape.size() == sgLayout.size() &&
-             "sgLayout and wgShape must have the same rank");
-      sgShape.reserve(wgShape.size());
-      for (size_t i = 0; i < wgShape.size(); ++i) {
-        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-        sgShape.push_back(wgShape[i] / sgLayout[i]);
-      }
-    }
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     // TODO : Handle order attribute
     // Get the subgroup ID
@@ -266,15 +279,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
-    SmallVector<Value> newDpasOps;
     size_t i = 0;
+    SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
+
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
@@ -288,10 +301,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                             originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
@@ -314,14 +327,69 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// Handles UnrealizedConversionCastOp generated during
+// SCFStructuralTypeConversions (step 1). This op may appear as either a
+// target or source materialization for Vector or TensorDesc, e.g.:
+// 1. unrealized_conversion_cast %1 : tensor_desc<16xf16> to
+// tensor_desc<128xf16, ...>
+// 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ...
+// 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// In all cases, the pattern simply forwards the inputs to the outputs with
+// one-to-one or one-to-n patterns.
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+    // Handles the case where cast %1 : tensor_desc<16xf16> to
+    // tensor_desc<128xf16, ...> The input values provided by the adaptor should
+    // already be distributed.
+    if (op.getNumOperands() == 1 && op.getNumResults() == 1 &&
+        isa<xegpu::TensorDescType>(op->getOperand(0).getType()) &&
+        isa<xegpu::TensorDescType>(op->getResult(0).getType())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
+    // the input values provided by the adaptor should already be distributed,
+    // and their types should correspond exactly to the result types of the
+    // operation.
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
+    // All input values must have the same vector type, and their shape must be
+    // evenly divisible by the output vector's shape.
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+    if (op.getNumResults() == 1 && inputTy && outputTy &&
+        llvm::all_equal(ValueRange(inputs).getTypes()) &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+      rewriter.replaceOpWithMultiple(op, {inputs});
+      return success();
+    }
+
+    return mlir::failure();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -334,6 +402,47 @@ struct XeGPUWgToSgDistributePass
 } // namespace
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](RankedTensorType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getSgShapeAndCount(
+            shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));
+
+        auto newTy = VectorType::get(subShape, elemTy);
+        result.append(count, newTy);
+        return success();
+      });
+
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
+
+  // step1: perform SCFStructuralTypeConversions on SCF ops
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
@@ -353,24 +462,27 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                                xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                                xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
     auto tdescTy = getTensorDescType(op);
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
     return isLegal(layout);
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
+  target.addIllegalOp<UnrealizedConversionCastOp>();
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+  // step2: Perform for workgroup to subgroup distribution for rest ops
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index dcaf4e85a82c5..6b85a66a8bd36 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -27,7 +27,7 @@
 using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
   SmallVector<Value> result;
   for (const auto &vals : values)
     llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
       // Only look at ops casting from VectorType to RankedTensorType
-      if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+      if (!inputTy || !resultTy)
         return WalkResult::skip();
 
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
         }
 
         if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
-          SmallVector<Value> values = flattenValues(adaptor.getInputs());
+          SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
           auto newOp = rewriter.create<UnrealizedConversionCastOp>(
               op.getLoc(), outputTy, values);
           rewriter.replaceOp(op, newOp);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..ff86e65300bb8 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -85,7 +85,7 @@ gpu.module @test_round_robin_assignment {
     %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
       -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
@@ -102,4 +102,28 @@ gpu.module @test_round_robin_assignment {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
+
+  gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+    %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+      %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+      //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+      scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+    } do {
+    // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
+    ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+      xegpu.store_nd %arg2, %2  : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %4 = arith.addi %arg3, %c1_i32 : i32
+      %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
+      scf.yield %6, %4 : vector<256xf32>, i32
+    }
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..d016d3a30a339 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -5,7 +5,7 @@
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {  
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
@@ -108,7 +108,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
@@ -142,7 +142,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
@@ -169,4 +169,68 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+
+    //CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    //CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    //CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    //CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    //CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    //CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    //CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+      %12 = xegpu.upd...
[truncated]

int count = 1;
SmallVector<int64_t> sgShape(shape);

if (layout && layout.isWgLayout()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isWgLayout seems confusing, I think it should be called isSgLayout since it describes how the subgroups are laid out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interface defined in a previous PR. I think we can create a small fix PR if we plan to change it.

// VectorType operands. This first converts such operands to RankedTensorType,
// propagates the layout attribute into the encoding attribute, and finally
// converts the RankedTensorType to VectorType based on the encoding.
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we need the "do"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it needs a verb-prefix. I am open to change it to something else instead of do.

%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>

//CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

break line, for both CHECKS and test case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

gpu.return
}


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should do one test with nested control flow?

scf.for {
scf.if {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found we need to do it later, since elementwise or constant is not supported yet. It is hard to have a meaningful test without them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok we can add it later

return failure();

// Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
// the input values provided by the adaptor should already be distributed,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you clarify who distributed the input? is it done by SCFStructuralTypeConversions? that means the input maybe coming from some structural op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an example for the pattern, hope it can help to understand. For arguments and results (N:1 case), they are generated by SCFStructuralTypeConversions, for 1:N case, they are generated by patterns of, e.g., create_nd etc.

// VectorType operands. This first converts such operands to RankedTensorType,
// propagates the layout attribute into the encoding attribute, and finally
// converts the RankedTensorType to VectorType based on the encoding.
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after this point how does the IR looks like? all SCF operations are distributed and there is no ranked tensor type in the IR. is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. All VectorType operands/arguments/Results of SCF::If, SCF::For, SCF::While and SCF::Condition ops will be converted.

if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
op->setAttr(name, layout.dropInstData());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why drop inst data?

Copy link
Contributor Author

@chencha3 chencha3 Jun 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bug. Fixed

Copy link
Contributor

@nbpatel nbpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chencha3 chencha3 merged commit 5578bcb into llvm:main Jun 13, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
… to subgroup distribution (llvm#142618)

This PR introduces support for `scf::ForOp`, `scf::WhileOp`, `scf::If`,
and `scf::Condition` within the workgroup-subgroup-distribution pass,
leveraging the `SCFStructuralTypeConversionsAndLegality`.
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
… to subgroup distribution (llvm#142618)

This PR introduces support for `scf::ForOp`, `scf::WhileOp`, `scf::If`,
and `scf::Condition` within the workgroup-subgroup-distribution pass,
leveraging the `SCFStructuralTypeConversionsAndLegality`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants