
[mlir][vector] Add scalable lowering for transfer_write(transpose) #101353


Merged: 6 commits merged from vector_to_scf_scalable_transpose_store into llvm:main on Aug 12, 2024

Conversation

@MacDue (Member) commented Jul 31, 2024

This specifically handles the case of a transpose from a vector type like vector<8x[4]xf32> to vector<[4]x8xf32>. Such transposes occur fairly frequently when scalably vectorizing linalg.generics. There is no direct lowering for these (as types like vector<[4]x8xf32> cannot be represented in LLVM-IR). However, if the only use of the transpose is a write, then it is possible to lower the transfer_write(transpose) as a VLA loop.

Example:

%transpose = vector.transpose %vec, [1, 0]
   : vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
   : vector<[4]x4xf32>,  memref<?x?xf32>

Becomes:

#map = affine_map<(d0)[s0] -> (d0 + s0)>
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
%1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
%2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
%3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
scf.for %idx = %c0 to %c4_vscale step %c1 {
  %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
  %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
  %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
  %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
  %slice_i = affine.apply #map(%idx)[%i]
  %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
  vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
    : vector<4xf32>, memref<?x?xf32>
}
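
This lowering is opt-in: it is enabled via the new lower-scalable option on convert-vector-to-scf (or VectorTransferToSCFOptions::enableLowerScalable from C++). For example, mirroring the updated RUN lines in the tests (with input.mlir standing in for the file to process):

mlir-opt input.mlir \
  -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true lower-scalable=true}))"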

@llvmbot (Member) commented Jul 31, 2024

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)


Patch is 23.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101353.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+3-1)
  • (modified) mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h (+8)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+186-1)
  • (modified) mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir (+14-1)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+117-9)
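
For anyone driving the conversion from C++ rather than through mlir-opt, the new option is exposed on VectorTransferToSCFOptions (see the VectorToSCF.h hunk below). A minimal sketch, assuming the usual greedy pattern-rewrite setup; only lowerScalable/enableLowerScalable come from this patch, the rest is standard MLIR boilerplate:

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch: lower vector transfers to SCF with the scalable lowerings enabled.
static LogicalResult lowerTransfersToSCF(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  VectorTransferToSCFOptions options;
  options.unroll = true;          // Unrolls until the first scalable dim.
  options.enableLowerScalable();  // Opt in to the new scalable patterns.
  populateVectorToSCFConversionPatterns(patterns, options);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}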
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..7bde9e490e4f4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1300,7 +1300,9 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
     Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
            "Target vector rank to which transfer ops should be lowered">,
     Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
-           "Lower transfer ops that operate on tensors">
+           "Lower transfer ops that operate on tensors">,
+    Option<"lowerScalable", "lower-scalable", "bool", /*default=*/"false",
+           "Add scalable vector specific lowerings (that introduce loops)">
   ];
 }
 
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 1c834b6c69083..e0ef67c39a101 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -69,6 +69,14 @@ struct VectorTransferToSCFOptions {
     unroll = u;
     return *this;
   }
+  /// Enable scalable vector specific lowerings (which introduce loops). These
+  /// work alongside fullUnroll (which unrolls until the first scalable
+  /// dimension).
+  bool lowerScalable = false;
+  VectorTransferToSCFOptions enableLowerScalable(bool enable = true) {
+    lowerScalable = enable;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + func.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 19f02297bfbb7..c2fb8b6161ec3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///    : vector<[4]x4xf32>,  memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (isTensorOp(writeOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "lowering tensor transfers is disabled");
+    }
+
+    auto vector = writeOp.getVector();
+    auto vectorType = vector.getType();
+    auto scalableFlags = vectorType.getScalableDims();
+    if (scalableFlags != ArrayRef<bool>{true, false}) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "expected vector of form vector<[*]x*xty>");
+    }
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "non-identity permutations are unsupported (lower first)");
+    }
+
+    if (!writeOp.isDimInBounds(0)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "out-of-bounds dims are unsupported (use masking)");
+    }
+
+    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp ||
+        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+    }
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+    if (failed(maskDims)) {
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "failed to resolve mask dims");
+    }
+
+    int64_t fixedDimSize = vectorType.getDimSize(1);
+    auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+    // Extract all slices from the source of the transpose.
+    auto transposeSource = transposeOp.getVector();
+    SmallVector<Value> transposeSourceSlices =
+        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+        });
+
+    // Loop bounds and step.
+    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto ub =
+        maskDims->empty()
+            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Generate a new mask for the slice.
+    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+    Value sliceMask = nullptr;
+    if (!maskDims->empty()) {
+      sliceMask = rewriter.create<vector::CreateMaskOp>(
+          loc, sliceType.clone(rewriter.getI1Type()),
+          ArrayRef<OpFoldResult>(*maskDims).drop_front());
+    }
+
+    ValueRange initLoopArgs =
+        isTensorOp(writeOp) ? writeOp.getSource() : ValueRange{};
+    auto result = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, initLoopArgs,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+          // Indices for the new transfer op.
+          SmallVector<Value, 8> xferIndices;
+          getXferIndices(b, writeOp, iv, xferIndices);
+
+          // Extract a transposed slice from the source vector.
+          SmallVector<Value> transposeElements =
+              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+                return b.create<vector::ExtractOp>(
+                    loc, transposeSourceSlices[idx], iv);
+              });
+          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+                                                           transposeElements);
+
+          // Create the transfer_write for the slice.
+          Value dest =
+              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+          auto newWriteOp = b.create<vector::TransferWriteOp>(
+              loc, sliceVec, dest, xferIndices,
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+          if (sliceMask)
+            newWriteOp.getMaskMutable().assign(sliceMask);
+
+          // Yield from the loop.
+          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+                                          ? ValueRange{}
+                                          : newWriteOp.getResult());
+        });
+
+    if (isTensorOp(writeOp))
+      rewriter.replaceOp(writeOp, result);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 } // namespace lowering_n_d
 
 namespace lowering_n_d_unrolled {
@@ -1503,7 +1683,10 @@ void mlir::populateVectorToSCFConversionPatterns(
                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
         patterns.getContext(), options);
   }
-
+  if (options.lowerScalable) {
+    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+        patterns.getContext(), options);
+  }
   if (options.targetRank == 1) {
     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
@@ -1522,6 +1705,7 @@ struct ConvertVectorToSCFPass
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerTensors = options.lowerTensors;
+    this->lowerScalable = options.lowerScalable;
   }
 
   void runOnOperation() override {
@@ -1529,6 +1713,7 @@ struct ConvertVectorToSCFPass
     options.unroll = fullUnroll;
     options.targetRank = targetRank;
     options.lowerTensors = lowerTensors;
+    options.lowerScalable = lowerScalable;
 
     // Lower permutation maps first.
     RewritePatternSet lowerTransferPatterns(&getContext());
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
index dac8e018f845f..c542b79d5c80e 100644
--- a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read_2d(
 //       CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
   return %t : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scalable_transpose_store
+//  CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
+//       CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
+//       CHECK:   %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
+//       CHECK:   scf.yield %[[WRITE_SLICE]]
+//       CHECK: return %[[RESULT]]
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: tensor<?x?xf32>, %i: index, %j: index) -> tensor<?x?xf32> {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %result = vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 3f4e70a6835af..f7ded02d83fe5 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
 // RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
@@ -661,10 +661,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 // CHECK:           memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
-// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VSCALE]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
 // CHECK:           }
 // CHECK:           %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           return %[[RESULT]] : vector<3x[4]xf32>
@@ -695,10 +695,10 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK:           memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
-// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
+// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VSCALE]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -803,3 +803,111 @@ func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
 // TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
 // TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
 // TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
+
+// -----
+
+func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL: #[[$SLICE_MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// FULL-UNROLL-LABEL:   func.func @scalable_transpose_store_unmasked(
+// FULL-UNROLL-SAME:                                                 %[[VEC:.*]]: vector<4x[4]xf32>,
+// FULL-UNROLL-SAME:                                                 %[[DEST:.*]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME:                                                 %[[I:.*]]: index,
+// FULL-UNROLL-SAME:                                                 %[[J:.*]]: index)
+// FULL-UNROLL-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// FULL-UNROLL-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// FULL-UNROLL-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// FULL-UNROLL:           %[[SLICE_0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_2:.*]] = vector.extract %[[VEC]][2] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[VSCALE:.*]] = vector.vscale
+// FULL-UNROLL:           %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// FULL-UNROLL:           scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+// FULL-UNROLL:             %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]]
+// FULL-UNROLL:             %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32>
+// FULL-UNROLL:             vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+// FULL-UNROLL:           }
+// FULL-UNROLL:           return
+
+// -----
+
+func.func @scalable_transpose_store_dynamic_mask(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index, %a: index, %b: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %mask = vector.create_mask %a, %b : vector<[4]x4xi1>
+  vector.transfer_write %transpose, %dest[%i, %j], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL:   func.func @scalable_transpose_store_dynamic_mask(
+// FULL-UNR...
[truncated]
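
The expected output for the masked case is cut off above. Inferred from the pattern itself (a sketch of the structure, not the verbatim CHECK lines): the first mask dimension (%a) becomes the loop upper bound, and the remaining dimension (%b) becomes a smaller vector<4xi1> mask applied to each per-slice write.

// Hypothetical structure, assuming the inputs of
// @scalable_transpose_store_dynamic_mask above.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%slice_mask = vector.create_mask %b : vector<4xi1>
scf.for %idx = %c0 to %a step %c1 {
  // (per-element extracts, vector.from_elements, and the slice index
  //  computation are the same as in the unmasked case)
  vector.transfer_write %slice, %dest[%slice_i, %j], %slice_mask {in_bounds = [true]}
    : vector<4xf32>, memref<?x?xf32>
}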

@c-rhodes (Collaborator) left a comment

Left a couple of minor comments, but otherwise LGTM, cheers.

@MacDue force-pushed the vector_to_scf_scalable_transpose_store branch 2 times, most recently from 6ef6af3 to c83f20c on August 7, 2024
@banach-space (Contributor) left a comment

LG % some minor suggestions

@MacDue force-pushed the vector_to_scf_scalable_transpose_store branch from ae9adad to 3359eb3 on August 8, 2024
MacDue added a commit to MacDue/iree that referenced this pull request Aug 9, 2024
This enables a general scalable lowering for `transfer_write(transpose)`
when ArmSME is _not_ available. The ArmSME dialect already had its own
(more specific) lowerings for cases like this, which is why these
lowerings are disabled when SME is available.

Depends on: llvm/llvm-project#101353
@banach-space (Contributor) left a comment

LGTM, thanks! Just a couple of minor nits.

@MacDue force-pushed the vector_to_scf_scalable_transpose_store branch from 3359eb3 to 8545557 on August 9, 2024
@MacDue added 6 commits on August 12, 2024
@MacDue force-pushed the vector_to_scf_scalable_transpose_store branch from 8545557 to 7b50608 on August 12, 2024
@MacDue merged commit 27a713f into llvm:main on Aug 12, 2024 (8 checks passed)
@MacDue deleted the vector_to_scf_scalable_transpose_store branch on August 12, 2024
c-rhodes pushed a commit to iree-org/iree that referenced this pull request Aug 16, 2024
This enables a general scalable lowering for `transfer_write(transpose)`
when ArmSME is _not_ available. The ArmSME dialect already had its own
(more specific) lowerings for cases like this, which is why these
lowerings are disabled when SME is available.

Depends on: llvm/llvm-project#101353

---------

Signed-off-by: Benjamin Maxwell <[email protected]>