Skip to content

[mlir][linalg] Add runtime verification for linalg ops #89917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H

namespace mlir {
// Forward declaration only; this header does not need the full definition.
class DialectRegistry;

namespace linalg {
/// Registers with `registry` an extension that attaches the
/// RuntimeVerifiableOpInterface external models for the linalg structured
/// ops.
void registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
Expand Down Expand Up @@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
"::mlir::Location":$loc)
>,
];

let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
static std::string generateErrorMessage(Operation *op, const std::string &msg);
}];
}

#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
RuntimeOpVerification.cpp
Specialize.cpp
Split.cpp
SplitReduction.cpp
Expand Down Expand Up @@ -60,6 +61,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRFuncDialect
MLIRFuncToLLVM
MLIRFuncTransforms
MLIRIndexDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
Expand Down
135 changes: 135 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

namespace mlir {
namespace linalg {
namespace {
/// External model of RuntimeVerifiableOpInterface for a linalg structured op
/// of type `T`.
///
/// Verify that the runtime sizes of the operands to linalg structured ops are
/// compatible with the runtime sizes inferred by composing the loop ranges with
/// the linalg op's indexing maps. This is similar to the verifier except that
/// here we insert IR to perform the verification at runtime.
///
/// For every dimension of every operand two `cf.assert` checks are requested:
///   1. the smallest index produced by the indexing map is non-negative;
///   2. the largest index produced by the indexing map, plus one, is
///      compatible with the actual size of that operand dimension.
///
/// NOTE(review): the loop ends are decremented by one before being composed
/// with the indexing maps. If a loop range is empty (end == 0) the decremented
/// end is -1, so the "non-negative" assertion would presumably fail for
/// zero-sized operands -- confirm whether empty iteration spaces need to be
/// special-cased.
template <typename T>
struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
/// Build the runtime checks for `op` using `builder`, tagging every created
/// operation with `loc`.
///
/// \param op      the operation to verify; it is cast to `LinalgOp`.
/// \param builder builder used to create (or fold) the verification IR.
/// \param loc     location attached to the created operations.
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto linalgOp = llvm::cast<LinalgOp>(op);

// Split the loop ranges into their components; only the starts and ends are
// used below, the third component is ignored.
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);

// Constants shared by all the checks generated below.
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto one = builder.create<arith::ConstantIndexOp>(loc, 1);

// Subtract one from the loop ends before composing with the indexing map, so
// that `ends` holds the last iteration index of each loop rather than the
// (exclusive) upper bound. The transformation is done in place.
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
return builder.createOrFold<index::SubOp>(loc, endValue, one);
});

for (OpOperand &opOperand : linalgOp->getOpOperands()) {
// Apply the operand's indexing map to the first and to the last iteration
// index; each call yields one result per map result.
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
builder, loc, indexingMap, starts);
auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
builder, loc, indexingMap, ends);

// One pair of assertions per operand dimension. Operands of rank 0 produce
// no assertion at all because the loop body never runs.
for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
auto startIndex =
getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
auto endIndex =
getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);

// Generate:
// minIndex = min(startIndex, endIndex)
// assert(minIndex >= 0)
// To ensure we do not generate a negative index. We take the minimum of
// the start and end indices in order to handle reverse loops such as
// `affine_map<(i) -> (3 - i)>`
auto min =
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
auto cmpOp = builder.createOrFold<index::CmpOp>(
loc, index::IndexCmpPredicate::SGE, min, zero);
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
std::to_string(opOperand.getOperandNumber()));
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);

// Generate:
// inferredDimSize = max(startIndex, endIndex) + 1
// actualDimSize = dim(operand)
// assert(inferredDimSize <= actualDimSize)
// To ensure that we do not index past the bounds of the operands.
auto max =
builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);

auto inferredDimSize =
builder.createOrFold<index::AddOp>(loc, max, one);

auto actualDimSize =
createOrFoldDimOp(builder, loc, opOperand.get(), dim);

// Similar to the verifier, when the affine expression in the indexing
// map is complicated, we just check that the inferred dimension sizes
// are in the boundary of the operands' size. Being more precise than
// that is difficult.
// Concretely: a plain dimension expression requires an exact size match
// (EQ); any other expression only requires inferred <= actual (SLE).
auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
? index::IndexCmpPredicate::EQ
: index::IndexCmpPredicate::SLE;

// `cmpOp` and `msg` are reused for the second assertion.
cmpOp = builder.createOrFold<index::CmpOp>(
loc, predicate, inferredDimSize, actualDimSize);
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
" is incompatible with inferred dimension size");
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
}
};

/// Attach `StructuredOpInterface<Op>` as an external model to every op type
/// `Op` in the pack `OpTs`, using the context `ctx`.
template <typename... OpTs>
void attachInterface(MLIRContext *ctx) {
// Fold expression over the comma operator: one `attachInterface` call per op
// type in the pack.
(OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
}
} // namespace
} // namespace linalg
} // namespace mlir

/// Add to `registry` an extension, keyed on the linalg dialect, that attaches
/// the runtime-verification external model to every op listed in
/// LinalgStructuredOps.cpp.inc and then loads the dialects whose ops the
/// generated verification code may create.
void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // The unary `+` turns the capture-less lambda into a plain function pointer.
  auto attachExternalModels = +[](MLIRContext *context,
                                  LinalgDialect * /*dialect*/) {
    // GET_OP_LIST makes the generated include expand to the comma-separated
    // list of structured op classes, used here as the template arguments.
    attachInterface<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(context);

    // Load additional dialects of which ops may get created.
    context->loadDialect<affine::AffineDialect, arith::ArithDialect,
                         cf::ControlFlowDialect, index::IndexDialect,
                         tensor::TensorDialect>();
  };
  registry.addExtension(attachExternalModels);
}
54 changes: 21 additions & 33 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,6 @@

using namespace mlir;

/// Build the text reported when runtime verification of `op` fails.
///
/// The text consists of a fixed header line, the printed form of `op`, the
/// caller-supplied `msg` prefixed with "^ ", and the location of `op`.
static std::string generateErrorMessage(Operation *op, const std::string &msg) {
  // Many of these messages may be generated, so pick the cheapest printing
  // options: elide large element attributes, use the generic op form, skip
  // regions and restrict printing to the local scope.
  OpPrintingFlags printFlags;
  printFlags.elideLargeElementsAttrs()
      .printGenericOpForm()
      .skipRegions()
      .useLocalScope();

  std::string text;
  llvm::raw_string_ostream os(text);
  os << "ERROR: Runtime op verification failed\n";
  op->print(os, printFlags);
  os << "\n^ " << msg << "\nLocation: ";
  op->getLoc().print(os);
  return os.str();
}

namespace mlir {
namespace memref {
namespace {
Expand All @@ -62,8 +43,10 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
Value isSameRank = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
builder.create<cf::AssertOp>(loc, isSameRank,
generateErrorMessage(op, "rank mismatch"));
builder.create<cf::AssertOp>(
loc, isSameRank,
RuntimeVerifiableOpInterface::generateErrorMessage(op,
"rank mismatch"));
}

// Get source offset and strides. We do not have an op to get offsets and
Expand Down Expand Up @@ -101,8 +84,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
builder.create<cf::AssertOp>(
loc, isSameSz,
generateErrorMessage(op, "size mismatch of dim " +
std::to_string(it.index())));
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size mismatch of dim " + std::to_string(it.index())));
}

// Get result offset and strides.
Expand All @@ -119,8 +102,10 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultOffset);
Value isSameOffset = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
builder.create<cf::AssertOp>(loc, isSameOffset,
generateErrorMessage(op, "offset mismatch"));
builder.create<cf::AssertOp>(
loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset mismatch"));
}

// Check strides.
Expand All @@ -137,8 +122,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
builder.create<cf::AssertOp>(
loc, isSameStride,
generateErrorMessage(op, "stride mismatch of dim " +
std::to_string(it.index())));
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "stride mismatch of dim " + std::to_string(it.index())));
}
}
};
Expand Down Expand Up @@ -178,7 +163,9 @@ struct LoadStoreOpInterface
: andOp;
}
builder.create<cf::AssertOp>(
loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "out-of-bounds access"));
}
};

Expand Down Expand Up @@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface

builder.create<cf::AssertOp>(
loc, assertCond,
generateErrorMessage(
RuntimeVerifiableOpInterface::generateErrorMessage(
op,
"result of reinterpret_cast is out-of-bounds of the base memref"));
}
Expand Down Expand Up @@ -293,8 +280,8 @@ struct SubViewOpInterface

builder.create<cf::AssertOp>(
loc, assertCond,
generateErrorMessage(op,
"subview is out-of-bounds of the base memref"));
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview is out-of-bounds of the base memref"));
}
};

Expand Down Expand Up @@ -334,8 +321,9 @@ struct ExpandShapeOpInterface
builder.create<arith::ConstantIndexOp>(loc, 0));
builder.create<cf::AssertOp>(
loc, isModZero,
generateErrorMessage(op, "static result dims in reassoc group do not "
"divide src dim evenly"));
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
"divide src dim evenly"));
}
}
};
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@
namespace mlir {
class Location;
class OpBuilder;

/// Build the text reported when runtime verification of `op` fails.
///
/// The text consists of a fixed header line, the printed form of `op`, the
/// caller-supplied `msg` prefixed with "^ ", and the location of `op`.
std::string
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
                                                   const std::string &msg) {
  // Many of these messages may be generated, so pick the cheapest printing
  // options: elide large element attributes, use the generic op form, skip
  // regions and restrict printing to the local scope.
  OpPrintingFlags printFlags;
  printFlags.elideLargeElementsAttrs()
      .printGenericOpForm()
      .skipRegions()
      .useLocalScope();

  std::string text;
  llvm::raw_string_ostream os(text);
  os << "ERROR: Runtime op verification failed\n";
  op->print(os, printFlags);
  os << "\n^ " << msg << "\nLocation: ";
  op->getLoc().print(os);
  return os.str();
}
} // namespace mlir

/// Include the definitions of the interface.
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Dialect/Linalg/runtime-verification.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s

// Most of the tests for linalg runtime-verification are implemented as integration tests.

#identity = affine_map<(d0) -> (d0)>

// Element-wise addition of two statically shaped 1-D tensors, with identity
// indexing maps on both inputs and on the output. The directives below expect
// an `index.bool.constant true` value to feed a `cf.assert`.
// CHECK-LABEL: @static_dims
func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
// CHECK: %[[TRUE:.*]] = index.bool.constant true
// CHECK: cf.assert %[[TRUE]]
%result = tensor.empty() : tensor<5xf32>
%0 = linalg.generic {
indexing_maps = [#identity, #identity, #identity],
iterator_types = ["parallel"]
} ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
outs(%result : tensor<5xf32>) {
^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
%tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
linalg.yield %tmp1 : f32
} -> tensor<5xf32>
return %0 : tensor<5xf32>
}

// -----

#map = affine_map<() -> ()>

// Addition of two rank-0 tensors: the indexing maps have no dimensions and no
// results, and the op has no iterators, so there is no operand dimension to
// verify.
// CHECK-LABEL: @scalars
func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// No runtime checks are required if the operands are all scalars
// CHECK-NOT: cf.assert
%result = tensor.empty() : tensor<f32>
%0 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = []
} ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
outs(%result : tensor<f32>) {
^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
%tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
linalg.yield %tmp1 : f32
} -> tensor<f32>
return %0 : tensor<f32>
}
Loading