Skip to content

[mlir][linalg] Add runtime verification for linalg ops #89917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ryanpholt
Copy link
Contributor

This commit implements runtime verification for LinalgStructuredOps using the existing RuntimeVerifiableOpInterface. The verification checks that the runtime sizes of the operands match the runtime sizes inferred by composing the loop ranges with the op's indexing maps.

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Ryan Holt (ryan-holt-1)

Changes

This commit implements runtime verification for LinalgStructuredOps using the existing RuntimeVerifiableOpInterface. The verification checks that the runtime sizes of the operands match the runtime sizes inferred by composing the loop ranges with the op's indexing maps.


Patch is 29.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89917.diff

9 Files Affected:

  • (added) mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h (+21)
  • (modified) mlir/include/mlir/InitAllDialects.h (+2)
  • (modified) mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td (+6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp (+135)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+21-33)
  • (modified) mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp (+22)
  • (added) mlir/test/Dialect/Linalg/runtime-verification.mlir (+43)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir (+298)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
new file mode 100644
index 00000000000000..6c3643f7835cbe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
@@ -0,0 +1,21 @@
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c4d788cf8ed316..d9db21073e15c7 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
   linalg::registerAllDialectInterfaceImplementations(registry);
+  linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index d5f11d00cc3d2a..6fd0df59d9d2e0 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
                     "::mlir::Location":$loc)
     >,
   ];
+
+  let extraClassDeclaration = [{
+    /// Generate the error message that will be printed to the user when 
+    /// verification fails.
+    static std::string generateErrorMessage(Operation *op, const std::string &msg);
+  }];
 }
 
 #endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..3b5282a09569d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
+  RuntimeOpVerification.cpp
   Specialize.cpp
   Split.cpp
   SplitReduction.cpp
@@ -60,6 +61,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRFuncDialect
   MLIRFuncToLLVM
   MLIRFuncTransforms
+  MLIRIndexDialect
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
new file mode 100644
index 00000000000000..b30182dc84079f
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -0,0 +1,135 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexAttrs.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace linalg {
+namespace {
+/// Verify that the runtime sizes of the operands to linalg structured ops are
+/// compatible with the runtime sizes inferred by composing the loop ranges with
+/// the linalg op's indexing maps. This is similar to the verifier except that
+/// here we insert IR to perform the verification at runtime.
+template <typename T>
+struct StructuredOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<
+          StructuredOpInterface<T>, T> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto linalgOp = llvm::cast<LinalgOp>(op);
+
+    SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
+    auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
+
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Subtract one from the loop ends before composing with the indexing map
+    transform(ends, ends.begin(), [&](OpFoldResult end) {
+      auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+      return builder.createOrFold<index::SubOp>(loc, endValue, one);
+    });
+
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+      auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
+          builder, loc, indexingMap, starts);
+      auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
+          builder, loc, indexingMap, ends);
+
+      for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
+        auto startIndex =
+            getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
+        auto endIndex =
+            getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
+
+        // Generate:
+        //   minIndex = min(startIndex, endIndex)
+        //   assert(minIndex >= 0)
+        // To ensure we do not generate a negative index. We take the minimum of
+        // the start and end indices in order to handle reverse loops such as
+        // `affine_map<(i) -> (3 - i)>`
+        auto min =
+            builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
+        auto cmpOp = builder.createOrFold<index::CmpOp>(
+            loc, index::IndexCmpPredicate::SGE, min, zero);
+        auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+            linalgOp, "unexpected negative result on dimension #" +
+                          std::to_string(dim) + " of input/output operand #" +
+                          std::to_string(opOperand.getOperandNumber()));
+        builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
+
+        // Generate:
+        //   inferredDimSize = max(startIndex, endIndex) + 1
+        //   actualDimSize = dim(operand)
+        //   assert(inferredDimSize <= actualDimSize)
+        // To ensure that we do not index past the bounds of the operands.
+        auto max =
+            builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
+
+        auto inferredDimSize =
+            builder.createOrFold<index::AddOp>(loc, max, one);
+
+        auto actualDimSize =
+            createOrFoldDimOp(builder, loc, opOperand.get(), dim);
+
+        // Similar to the verifier, when the affine expression in the indexing
+        // map is complicated, we just check that the inferred dimension sizes
+        // are in the boundary of the operands' size. Being more precise than
+        // that is difficult.
+        auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
+                             ? index::IndexCmpPredicate::EQ
+                             : index::IndexCmpPredicate::SLE;
+
+        cmpOp = builder.createOrFold<index::CmpOp>(
+            loc, predicate, inferredDimSize, actualDimSize);
+        msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+            linalgOp, "dimension #" + std::to_string(dim) +
+                          " of input/output operand #" +
+                          std::to_string(opOperand.getOperandNumber()) +
+                          " is incompatible with inferred dimension size");
+        builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
+      }
+    }
+  }
+};
+
+template <typename... OpTs>
+void attachInterface(MLIRContext *ctx) {
+  (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
+}
+} // namespace
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
+    attachInterface<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+
+    // Load additional dialects of which ops may get created.
+    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
+                     cf::ControlFlowDialect, index::IndexDialect,
+                     tensor::TensorDialect>();
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 05b813a3b1e908..450bfa0cec0c7f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -20,25 +20,6 @@
 
 using namespace mlir;
 
-/// Generate an error message string for the given op and the specified error.
-static std::string generateErrorMessage(Operation *op, const std::string &msg) {
-  std::string buffer;
-  llvm::raw_string_ostream stream(buffer);
-  OpPrintingFlags flags;
-  // We may generate a lot of error messages and so we need to ensure the
-  // printing is fast.
-  flags.elideLargeElementsAttrs();
-  flags.printGenericOpForm();
-  flags.skipRegions();
-  flags.useLocalScope();
-  stream << "ERROR: Runtime op verification failed\n";
-  op->print(stream, flags);
-  stream << "\n^ " << msg;
-  stream << "\nLocation: ";
-  op->getLoc().print(stream);
-  return stream.str();
-}
-
 namespace mlir {
 namespace memref {
 namespace {
@@ -62,8 +43,10 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
       Value isSameRank = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcRank, resultRank);
-      builder.create<cf::AssertOp>(loc, isSameRank,
-                                   generateErrorMessage(op, "rank mismatch"));
+      builder.create<cf::AssertOp>(
+          loc, isSameRank,
+          RuntimeVerifiableOpInterface::generateErrorMessage(op,
+                                                             "rank mismatch"));
     }
 
     // Get source offset and strides. We do not have an op to get offsets and
@@ -101,8 +84,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
       builder.create<cf::AssertOp>(
           loc, isSameSz,
-          generateErrorMessage(op, "size mismatch of dim " +
-                                       std::to_string(it.index())));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "size mismatch of dim " + std::to_string(it.index())));
     }
 
     // Get result offset and strides.
@@ -119,8 +102,10 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultOffset);
       Value isSameOffset = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
-      builder.create<cf::AssertOp>(loc, isSameOffset,
-                                   generateErrorMessage(op, "offset mismatch"));
+      builder.create<cf::AssertOp>(
+          loc, isSameOffset,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "offset mismatch"));
     }
 
     // Check strides.
@@ -137,8 +122,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
       builder.create<cf::AssertOp>(
           loc, isSameStride,
-          generateErrorMessage(op, "stride mismatch of dim " +
-                                       std::to_string(it.index())));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "stride mismatch of dim " + std::to_string(it.index())));
     }
   }
 };
@@ -178,7 +163,9 @@ struct LoadStoreOpInterface
                 : andOp;
     }
     builder.create<cf::AssertOp>(
-        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+        loc, assertCond,
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "out-of-bounds access"));
   }
 };
 
@@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        generateErrorMessage(
+        RuntimeVerifiableOpInterface::generateErrorMessage(
             op,
             "result of reinterpret_cast is out-of-bounds of the base memref"));
   }
@@ -293,8 +280,8 @@ struct SubViewOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        generateErrorMessage(op,
-                             "subview is out-of-bounds of the base memref"));
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "subview is out-of-bounds of the base memref"));
   }
 };
 
@@ -334,8 +321,9 @@ struct ExpandShapeOpInterface
           builder.create<arith::ConstantIndexOp>(loc, 0));
       builder.create<cf::AssertOp>(
           loc, isModZero,
-          generateErrorMessage(op, "static result dims in reassoc group do not "
-                                   "divide src dim evenly"));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "static result dims in reassoc group do not "
+                  "divide src dim evenly"));
     }
   }
 };
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 9205d8d8c34a29..e823b5df179c50 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -11,6 +11,28 @@
 namespace mlir {
 class Location;
 class OpBuilder;
+
+/// Generate an error message string for the given op and the specified error.
+std::string
+RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
+                                                   const std::string &msg) {
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+  OpPrintingFlags flags;
+  // We may generate a lot of error messages and so we need to ensure the
+  // printing is fast.
+  flags.elideLargeElementsAttrs();
+  flags.printGenericOpForm();
+  flags.skipRegions();
+  flags.useLocalScope();
+  stream << "ERROR: Runtime op verification failed\n";
+  op->print(stream, flags);
+  stream << "\n^ " << msg;
+  stream << "\nLocation: ";
+  op->getLoc().print(stream);
+  return stream.str();
+}
+
 } // namespace mlir
 
 /// Include the definitions of the interface.
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
new file mode 100644
index 00000000000000..a4f29d8457e589
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+
+// Most of the tests for linalg runtime-verification are implemented as integration tests.
+
+#identity = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @static_dims
+func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
+    // CHECK: %[[TRUE:.*]] = index.bool.constant true
+    // CHECK: cf.assert %[[TRUE]]
+    %result = tensor.empty() : tensor<5xf32> 
+    %0 = linalg.generic {
+      indexing_maps = [#identity, #identity, #identity],
+      iterator_types = ["parallel"]
+    } ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
+      outs(%result : tensor<5xf32>) {
+      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+        linalg.yield %tmp1 : f32
+    } -> tensor<5xf32>
+    return %0 : tensor<5xf32>
+}
+
+// -----
+
+#map = affine_map<() -> ()>
+
+// CHECK-LABEL: @scalars
+func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
+    // No runtime checks are required if the operands are all scalars
+    // CHECK-NOT: cf.assert
+    %result = tensor.empty() : tensor<f32> 
+    %0 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = []
+    } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+      outs(%result : tensor<f32>) {
+      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+        linalg.yield %tmp1 : f32
+    } -> tensor<f32>
+    return %0 : tensor<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
new file mode 100644
index 00000000000000..b05ef9422e5967
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -0,0 +1,298 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -convert-linalg-to-loops \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-scf-to-cf \
+// RUN: -test-cf-assert \
+// RUN: -convert-index-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils \
+// RUN:     -shared-libs=%mlir_c_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  %c5x = arith.constant dense<0.0> : tensor<5xf32>
+  %c4x = arith.constant dense<0.0> : tensor<4xf32>
+  %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32>
+  %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32>
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK: linalg.generic
+  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+  func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK: linalg.generic
+  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+  func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32>
+  %c1x4 = arith.constant dense<0.0> : tensor<1x4xf32>
+  %c4x4 = arith.constant dense<0.0> : tensor<4x4xf32>
+  %c4x5 = arith.constant dense<0.0> : tensor<4x5xf32>
+  %c5x4 = arith.constant dense<0.0> : tensor<5x4xf32>
+  %d1x1 = tensor.cast %c1x1 : tensor<1x1xf32> to tensor<?x?xf32>
+  %d1x4 = tensor.cast %c1x4 : tensor<1x4xf32> to tensor<?x?xf32>
+  %d4x4 = tensor.cast %c4x4 : tensor<4x4xf32> to tensor<?x?xf32>
+  %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32>
+  %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32>
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tenso...
[truncated]

@ryanpholt
Copy link
Contributor Author

ryanpholt commented Apr 24, 2024

#89342 was reverted in #89780 due to a build failure. This PR re-commits the original change with the build fix (a missing dependency on the index dialect).

This commit implements runtime verification for LinalgStructuredOps
using the existing `RuntimeVerifiableOpInterface`. The verification
checks that the runtime sizes of the operands match the runtime sizes
inferred by composing the loop ranges with the op's indexing maps.
@ryanpholt ryanpholt force-pushed the reland-linalg-runtime-verification branch from 4561111 to ade076e Compare April 25, 2024 12:03
@ryanpholt
Copy link
Contributor Author

@sabauma Can you approve/merge this PR when you get a chance? It is identical to #89342 except that I added a missing dependency on the index dialect to fix the build bot.

@matthias-springer matthias-springer merged commit d94aeb5 into llvm:main Apr 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants