[mlir][Transforms] Add a PadTilingInterface transformation and hook i… #144991

Merged: 2 commits, Jun 20, 2025

Conversation

@nicolasvasilache (Contributor) commented Jun 20, 2025

…t up to the transform dialect

This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.

Compared to structured.transform.pad, it has the following differences:

  • No support for nofold, copy-back, transpose, and hoisting: these were carried by the padding op in the very early days of StructuredOps and have since been separated out as independent transformations that compose.
  • No conflated static bounding-box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
  • Properly derives padding sizes on outputs where multiple dimensions contribute: this is not supported by structured.transform.pad.
  • Geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default).

This will gradually replace structured.transform.pad as it is fleshed out and tested more comprehensively.

In the future this should be moved out of a specific Linalg implementation file and into a more general "Structured" file.
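Based on the assembly format declared in the patch below, usage of the new op in a transform script might look roughly like the following. This is a sketch only: the matched payload op, padding dimensions, sizes, and padding values are illustrative, and the exact concrete syntax is defined by the `.td` file in the diff.

```mlir
// Sketch of a transform script driving the new op (illustrative; the
// matmul match, dimensions [0, 2], and sizes [8, 16] are made up).
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
    : (!transform.any_op) -> !transform.any_op
  // Pad iteration dimensions 0 and 2 up to multiples of 8 and 16.
  %padded, %pad = transform.structured.pad_tiling_interface %matmul
    to padding_sizes [8, 16] pad_to_multiple_of
    {padding_dimensions = [0, 2],
     padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32]}
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  transform.yield
}
```

Per the op's verifier, the number of `padding_sizes` entries must match the number of `padding_dimensions`; the two results are handles to the padded ops and to the created `tensor.pad` ops, respectively.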

@llvmbot (Member) commented Jun 20, 2025

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

Changes



Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144991.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+76)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+76-3)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+161)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp (+322)
  • (added) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir (+73)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6f6df350f1ba6..c00151dea54a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    In the future, this operation will support all TilingInterfaceOps.
+
+    This operation may produce a definite failure if the padding fails for any
+    reason.
+
+    If all the operations referred to by the `target` handle pad properly, the
+    transform succeeds. Otherwise the transform produces a silenceable failure.
+    The return handle points to only the subset of successfully produced
+    padded operations, which can be empty.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+            $static_padding_sizes,
+         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
+
+  let assemblyFormat = [{
+    $target
+    `to`
+    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
+    (`pad_to_multiple_of` $pad_to_multiple_of^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+
+  let builders = [
+    // Builder for a transform::PadOp with automatic inference of padding
+    // value. Warning: this will set the value 0 for the inferred elemental
+    // type without taking the op into account and thus only work for the
+    // add/mul ring at the moment.
+    // TODO: support other operations (e.g. min, max etc).
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
+    SmallVector<OpFoldResult> getMixedPaddingSizes();
+
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // HoistPadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..59b7fdeef10b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
   }
 };
 
+struct PadTilingInterfaceOptions {
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
+    return *this;
+  }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
+  /// A list of iterator dimensions sizes to pad to.
+  SmallVector<OpFoldResult> paddingSizes;
+  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
+    paddingSizes.assign(m.begin(), m.end());
+    return *this;
+  }
+  /// Pad iterator `paddingDimension[i]` to next multiple of `paddingSizes[i]`
+  /// if true. Otherwise pad to `paddingSizes[i]`.
+  bool padToMultipleOf;
+  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
+    padToMultipleOf = b;
+    return *this;
+  }
+};
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 /// where relevant.
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
-/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. The original `opToPad` is cloned and operates on
-/// the padded tensors.
+/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
+/// operands to a static bounding box. The original `opToPad` is cloned and
+/// operates on the padded tensors.
 ///
 /// * "options.padToMultipleOf" indicates that each padding dimension should be
 ///   padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                 SmallVector<Value> &replacements,
                                 SmallVector<tensor::PadOp> &padOps);
 
+/// Helper function to compute the padded shape of the given value `v` of
+/// `RankedTensorType` given:
+///   - the `indexingSizes` as a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the padded shape varies with
+///     increases in `indexingSizes`.
+/// The implementation iteratively combines increases from contributing dimensions using
+/// affine.apply operations.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult>
+computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+                   AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+                   const PadTilingInterfaceOptions &options);
+
+using PadSizeComputationFunction =
+    std::function<FailureOr<SmallVector<OpFoldResult>>(
+        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        const PadTilingInterfaceOptions &)>;
+
+/// Specific helper for Linalg ops.
+FailureOr<SmallVector<OpFoldResult>>
+computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
+                         ArrayRef<Range> iterationDomain,
+                         const PadTilingInterfaceOptions &options);
+
+/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
+///
+/// * "options.paddingSizes" indicates that each padding dimension should be
+///   padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that the paddingSizes should be
+///   interpreted as the bounding box (dynamic) value to pad to.
+/// * Use "options.paddingValues" to set the padding value of the created
+///   tensor::PadOp.
+/// * The tensor::PadOp is returned on success.
+
+FailureOr<TilingInterface>
+rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+                  const PadTilingInterfaceOptions &constOptions,
+                  SmallVector<tensor::PadOp> &padOps,
+                  PadSizeComputationFunction computePaddingSizeFun =
+                      &computeLinalgPaddedShape);
+
 namespace detail {
 
 /// Helper struct to hold the results of building a packing loop nest.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d78c8847f8843..bc8ff8e55bf0f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -45,6 +45,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include <type_traits>
 
 using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+                                            OperationState &result,
+                                            Value target,
+                                            ArrayRef<int64_t> paddingDimensions,
+                                            ArrayRef<int64_t> paddingSizes,
+                                            bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/ValueRange{},
+               /*paddingSizes=*/
+               (paddingSizes.empty() ? DenseI64ArrayAttr()
+                                     : b.getDenseI64ArrayAttr(paddingSizes)),
+               /*padToMultipleOf=*/
+               padToMultipleOf ? b.getUnitAttr() : nullptr);
+}
+
+void transform::PadTilingInterfaceOp::build(
+    OpBuilder &b, OperationState &result, Value target,
+    ArrayRef<int64_t> paddingDimensions,
+    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  SmallVector<int64_t> staticPaddingSizes;
+  SmallVector<Value> dynamicPaddingSizes;
+  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
+                             staticPaddingSizes);
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/dynamicPaddingSizes,
+               /*paddingSizes=*/staticPaddingSizes,
+               /*usePrescribedTensorShapes=*/padToMultipleOf);
+}
+
+void transform::PadTilingInterfaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getPaddingSizesMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+  Builder b(getContext());
+  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
+}
+
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << std::get<0>(it);
+          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
+    }
+
+    // Set options.
+    TilingInterface paddedOp;
+    PadTilingInterfaceOptions options;
+    options.setPaddingValues(paddingValues)
+        .setPaddingDimensions(
+            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
+        .setPaddingSizes(getMixedPaddingSizes())
+        .setPadToMultipleOf(getPadToMultipleOf());
+
+    // Apply padding.
+    SmallVector<tensor::PadOp> newPadOps;
+    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
+        newPadOps);
+    if (failed(maybePaddedOp)) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Set transform results.
+    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+    padOps.append(newPadOps.begin(), newPadOps.end());
+  }
+
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::PadTilingInterfaceOp::verify() {
+  SmallVector<int64_t> paddingDimensions =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError() << "expects padding_dimensions to contain "
+                            "non-negative integers, found "
+                         << getPaddingDimensions();
+  }
+  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
+    return emitOpError() << "expects as many multiples as padding_dimensions";
+  }
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // HoistPadOp
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 881d9fcb4f52e..69e6fdabf9a58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
+  PadTilingInterface.cpp
   Promotion.cpp
   RuntimeOpVerification.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
new file mode 100644
index 0000000000000..a9d7bc64f2a6b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -0,0 +1,322 @@
+//===- PaddingTilingInterface.cpp - Padding of TilingInterface ops --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+#define DEBUG_TYPE "pad-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::tensor;
+
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+/// Form a "full-rank" padding specification so that the application is easy.
+static llvm::SmallDenseMap<int64_t, OpFoldResult>
+getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
+              const PadTilingInterfaceOptions &options) {
+  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize;
+  for (const auto &[paddingDim, paddingSize] :
+       llvm::zip_equal(options.paddingDimensions, options.paddingSizes)) {
+    dimsToSize[paddingDim] = paddingSize;
+  }
+  // Complete the padding specification to specify all dimensions.
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    if (dimsToSize.find(idx) != dimsToSize.end())
+      continue;
+    // If a dimension is not specified, either complete with:
+    //   - 1 if we are padding to the next multiple of.
+    //   - indexingSizes[idx] otherwise
+    dimsToSize[idx] =
+        options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
+  }
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << dimsToSize[idx]
+                      << "\n");
+  }
+  return dimsToSize;
+}
+
+/// Compute the padded shape of the given value `v` of `RankedTensorType` given
+///   - `indexingSizes` a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the shape varies with increases
+///     in `indexingSizes`.
+/// The `indexingMap` encodes how the shape of varies w...
[truncated]
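The new PadTilingInterface.cpp is truncated above. As a rough, hand-written illustration of the rewrite it implements (not actual pass output; `%lhs`, `%h`, and the shapes are made up), each operand selected for padding is wrapped in a `tensor.pad` that yields the configured padding value, and the target op is then cloned to consume the padded tensors:

```mlir
// Illustrative sketch: pad the first dimension of %lhs with zeros.
// %h is the (possibly dynamic) amount of high padding computed from the
// padding sizes; the cloned op then operates on %lhs_padded.
%cst = arith.constant 0.0 : f32
%lhs_padded = tensor.pad %lhs low[0, 0] high[%h, 0] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
```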

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…t up to the transform dialect

This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.

Compared to structured.transform.pad it has the following differences:

  • no support for nofold, copy-back, transpose and hoisting: these have been carried by the padding op in the very early days of StructuredOps and have since then been separated out as independent transformations that compose.
  • no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
  • properly derives padding size on outputs where multiple dimensions contribute: this is not supported in structured.transform.pad
  • geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default)

This will gradually replace structured.transform.pad as it is fleshed out and tested more comprehensively.


Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144991.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+76)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+76-3)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+161)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp (+322)
  • (added) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir (+73)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6f6df350f1ba6..c00151dea54a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    In the future, this operation will support all TilingInterfaceOps.
+
+    This operation may produce a definite failure if the padding fails for any
+    reason.
+
+    If all the operations referred to by the `target` handle pad properly, the
+    transform succeeds. Otherwise the transform produces a silenceable failure.
+    The return handle points to only the subset of successfully produced
+    padded operations, which can be empty.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+            $static_padding_sizes,
+         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
+
+  let assemblyFormat = [{
+    $target
+    `to`
+    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
+    (`pad_to_multiple_of` $pad_to_multiple_of^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+
+  let builders = [
+    // Builder for a transform::PadOp with automatic inference of padding
+    // value. Warning: this will set the value 0 for the inferred elemental
+    // type without taking the op into account and thus only work for the
+    // add/mul ring at the moment.
+    // TODO: support other operations (e.g. min, max etc).
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
+    SmallVector<OpFoldResult> getMixedPaddingSizes();
+
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // HoistPadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..59b7fdeef10b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
   }
 };
 
+struct PadTilingInterfaceOptions {
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
+    return *this;
+  }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
+  /// A list of iterator dimensions sizes to pad to.
+  SmallVector<OpFoldResult> paddingSizes;
+  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
+    paddingSizes.assign(m.begin(), m.end());
+    return *this;
+  }
+  /// Pad iterator `paddingDimensions[i]` to the next multiple of
+  /// `paddingSizes[i]` if true. Otherwise pad to `paddingSizes[i]`.
+  bool padToMultipleOf;
+  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
+    padToMultipleOf = b;
+    return *this;
+  }
+};
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 /// where relevant.
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
-/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. The original `opToPad` is cloned and operates on
-/// the padded tensors.
+/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
+/// operands to a static bounding box. The original `opToPad` is cloned and
+/// operates on the padded tensors.
 ///
 /// * "options.padToMultipleOf" indicates that each padding dimension should be
 ///   padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                 SmallVector<Value> &replacements,
                                 SmallVector<tensor::PadOp> &padOps);
 
+/// Helper function to compute the padded shape of the given value `v` of
+/// `RankedTensorType` given:
+///   - the `indexingSizes` as a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the padded shape varies with
+///     increases in `indexingSizes`.
+/// The implementation iteratively combines increases from contributing terms
+/// using affine.apply operations.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult>
+computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+                   AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+                   const PadTilingInterfaceOptions &options);
+
+using PadSizeComputationFunction =
+    std::function<FailureOr<SmallVector<OpFoldResult>>(
+        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        const PadTilingInterfaceOptions &)>;
+
+/// Specific helper for Linalg ops.
+FailureOr<SmallVector<OpFoldResult>>
+computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
+                         ArrayRef<Range> iterationDomain,
+                         const PadTilingInterfaceOptions &options);
+
+/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
+///
+/// * "options.paddingSizes" indicates that each padding dimension should be
+///   padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that `paddingSizes` should be
+///   interpreted as multiples to pad the iteration domain to, rather than as
+///   absolute sizes.
+/// * Use "options.paddingValues" to set the padding value of the created
+///   tensor::PadOp.
+/// * The tensor::PadOp is returned on success.
+FailureOr<TilingInterface>
+rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+                  const PadTilingInterfaceOptions &constOptions,
+                  SmallVector<tensor::PadOp> &padOps,
+                  PadSizeComputationFunction computePaddingSizeFun =
+                      &computeLinalgPaddedShape);
+
 namespace detail {
 
 /// Helper struct to hold the results of building a packing loop nest.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d78c8847f8843..bc8ff8e55bf0f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -45,6 +45,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include <type_traits>
 
 using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+                                            OperationState &result,
+                                            Value target,
+                                            ArrayRef<int64_t> paddingDimensions,
+                                            ArrayRef<int64_t> paddingSizes,
+                                            bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/ValueRange{},
+               /*staticPaddingSizes=*/
+               (paddingSizes.empty() ? DenseI64ArrayAttr()
+                                     : b.getDenseI64ArrayAttr(paddingSizes)),
+               /*padToMultipleOf=*/
+               padToMultipleOf ? b.getUnitAttr() : nullptr);
+}
+
+void transform::PadTilingInterfaceOp::build(
+    OpBuilder &b, OperationState &result, Value target,
+    ArrayRef<int64_t> paddingDimensions,
+    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  SmallVector<int64_t> staticPaddingSizes;
+  SmallVector<Value> dynamicPaddingSizes;
+  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
+                             staticPaddingSizes);
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/dynamicPaddingSizes,
+               /*staticPaddingSizes=*/staticPaddingSizes,
+               /*usePrescribedTensorShapes=*/padToMultipleOf);
+}
+
+void transform::PadTilingInterfaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getPaddingSizesMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+  Builder b(getContext());
+  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
+}
+
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << std::get<0>(it);
+          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
+    }
+
+    // Set options.
+    TilingInterface paddedOp;
+    PadTilingInterfaceOptions options;
+    options.setPaddingValues(paddingValues)
+        .setPaddingDimensions(
+            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
+        .setPaddingSizes(getMixedPaddingSizes())
+        .setPadToMultipleOf(getPadToMultipleOf());
+
+    // Apply padding.
+    SmallVector<tensor::PadOp> newPadOps;
+    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
+        newPadOps);
+    if (failed(maybePaddedOp)) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Set transform results.
+    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+    padOps.append(newPadOps.begin(), newPadOps.end());
+  }
+
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::PadTilingInterfaceOp::verify() {
+  SmallVector<int64_t> paddingDimensions =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError() << "expects padding_dimensions to contain "
+                            "non-negative integers, found "
+                         << getPaddingDimensions();
+  }
+  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
+    return emitOpError() << "expects as many padding sizes as "
+                            "padding_dimensions";
+  }
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // HoistPadOp
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 881d9fcb4f52e..69e6fdabf9a58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
+  PadTilingInterface.cpp
   Promotion.cpp
   RuntimeOpVerification.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
new file mode 100644
index 0000000000000..a9d7bc64f2a6b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -0,0 +1,322 @@
+//===- PadTilingInterface.cpp - Padding of TilingInterface ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+#define DEBUG_TYPE "pad-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::tensor;
+
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+/// Form a "full-rank" padding specification so that it is easy to apply.
+static llvm::SmallDenseMap<int64_t, OpFoldResult>
+getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
+              const PadTilingInterfaceOptions &options) {
+  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize;
+  for (const auto &[paddingDim, paddingSize] :
+       llvm::zip_equal(options.paddingDimensions, options.paddingSizes)) {
+    dimsToSize[paddingDim] = paddingSize;
+  }
+  // Complete the padding specification to specify all dimensions.
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    if (dimsToSize.find(idx) != dimsToSize.end())
+      continue;
+    // If a dimension is not specified, complete it with:
+    //   - 1 if we are padding to the next multiple (1 is neutral);
+    //   - indexingSizes[idx] otherwise.
+    dimsToSize[idx] =
+        options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
+  }
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << dimsToSize[idx]
+                      << "\n");
+  }
+  return dimsToSize;
+}
+
+/// Compute the padded shape of the given value `v` of `RankedTensorType` given
+///   - `indexingSizes` a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the shape of `v` varies with
+///     increases in `indexingSizes`.
+/// The `indexingMap` encodes how the shape of `v` varies w...
[truncated]
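The body of the new `PadTilingInterface.cpp` is truncated above, but the size bookkeeping it introduces — `getDimsToSize` completing a partial padding specification, plus the pad-to-multiple semantics documented on `PadTilingInterfaceOptions` — can be modeled in plain Python. This is a sketch of the semantics only, not the actual implementation (which composes `affine.apply` operations over `OpFoldResult`s); all names below are illustrative.

```python
def get_dims_to_size(indexing_sizes, padding_dims, padding_sizes,
                     pad_to_multiple_of):
    """Complete a partial padding spec so every iteration dim is covered,
    mirroring getDimsToSize: unspecified dims get 1 (the neutral multiple)
    when padding to multiples, or their original size otherwise."""
    dims_to_size = dict(zip(padding_dims, padding_sizes))
    for idx, size in enumerate(indexing_sizes):
        if idx not in dims_to_size:
            dims_to_size[idx] = 1 if pad_to_multiple_of else size
    return dims_to_size

def padded_size(size, spec, pad_to_multiple_of):
    """Pad one dimension: round up to the next multiple of `spec`, or pad
    up to the absolute size `spec` (padding never shrinks)."""
    if pad_to_multiple_of:
        return -(-size // spec) * spec  # ceiling division, then rescale
    return max(size, spec)

# Iteration domain of extents [13, 20, 7]; pad dims 0 and 2 to multiples of 8.
spec = get_dims_to_size([13, 20, 7], [0, 2], [8, 8], True)
print([padded_size(s, spec[i], True) for i, s in enumerate([13, 20, 7])])
# [16, 20, 8]
```

Dimension 1 is untouched because its completed multiple is 1; with `pad_to_multiple_of=False` the same sizes would be taken as absolute target extents instead.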

@Groverkss Groverkss (Member) left a comment

LGTM, much cleaner than the old padding transform!

Can we add a note somewhere that eventually this transform should be moved out of linalg? For now I know it can only work on linalg (and sets the roots for it to work on TilingInterface) so it's okay.
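The PR description highlights that, unlike `structured.transform.pad`, the new transformation "properly derives padding size on outputs where multiple dimensions contribute". A deliberately crude additive model of that derivation in Python — the real code evaluates arbitrary affine indexing maps via `affine.apply`; the names here are illustrative:

```python
def padded_operand_shape(operand_dim_contribs, padded_iter_sizes):
    """For each operand dimension, combine the padded sizes of all
    iteration dimensions contributing to it (modeled as a plain sum,
    e.g. a conv-style indexing expression d0 + d2)."""
    return [sum(padded_iter_sizes[d] for d in contribs)
            for contribs in operand_dim_contribs]

# matmul-style output: each operand dim fed by a single iteration dim.
print(padded_operand_shape([[0], [1]], {0: 16, 1: 32}))  # [16, 32]
# conv-style operand dim fed by d0 + d2: both contributions combine.
print(padded_operand_shape([[0, 2]], {0: 16, 2: 3}))     # [19]
```

The second call is the case the old transform could not express: one operand dimension whose padded extent depends on two iteration dimensions at once.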

@nicolasvasilache nicolasvasilache merged commit 00c18d0 into main Jun 20, 2025
5 of 7 checks passed
@nicolasvasilache nicolasvasilache deleted the users/nico/revisit-pad branch June 20, 2025 10:31