Skip to content

Commit 8317d36

Browse files
authored
[mlir][linalg] Add runtime verification for linalg ops (#89342)
This commit implements runtime verification for LinalgStructuredOps using the existing `RuntimeVerifiableOpInterface`. The verification checks that the runtime sizes of the operands match the runtime sizes inferred by composing the loop ranges with the op's indexing maps.
1 parent c108653 commit 8317d36

File tree

9 files changed

+549
-33
lines changed

9 files changed

+549
-33
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
10+
#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
11+
12+
namespace mlir {
13+
// Forward declaration; only a reference to the registry is needed below.
class DialectRegistry;
14+
15+
namespace linalg {
16+
/// Adds to `registry` the extension that attaches the
/// `RuntimeVerifiableOpInterface` external models for Linalg ops.
///
/// \param registry The dialect registry that receives the extension.
void registerRuntimeVerifiableOpInterfaceExternalModels(
17+
DialectRegistry &registry);
18+
} // namespace linalg
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
4646
#include "mlir/Dialect/Linalg/IR/Linalg.h"
4747
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
48+
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
4849
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
4950
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
5051
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
161162
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
162163
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
163164
linalg::registerAllDialectInterfaceImplementations(registry);
165+
linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
164166
memref::registerAllocationOpInterfaceExternalModels(registry);
165167
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
166168
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);

mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
3535
"::mlir::Location":$loc)
3636
>,
3737
];
38+
39+
let extraClassDeclaration = [{
40+
/// Generate the error message that will be printed to the user when
41+
/// verification fails.
42+
static std::string generateErrorMessage(Operation *op, const std::string &msg);
43+
}];
3844
}
3945

4046
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2727
NamedOpConversions.cpp
2828
Padding.cpp
2929
Promotion.cpp
30+
RuntimeOpVerification.cpp
3031
Specialize.cpp
3132
Split.cpp
3233
SplitReduction.cpp
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
10+
11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
12+
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/Arith/Utils/Utils.h"
14+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15+
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
16+
#include "mlir/Dialect/Index/IR/IndexDialect.h"
17+
#include "mlir/Dialect/Index/IR/IndexOps.h"
18+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21+
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
22+
23+
namespace mlir {
24+
namespace linalg {
25+
namespace {
26+
/// Verify that the runtime sizes of the operands to linalg structured ops are
27+
/// compatible with the runtime sizes inferred by composing the loop ranges with
28+
/// the linalg op's indexing maps. This is similar to the verifier except that
29+
/// here we insert IR to perform the verification at runtime.
///
/// The model is templated on the concrete op type `T` so that it can be
/// attached as an external model of `RuntimeVerifiableOpInterface` to `T`.
30+
template <typename T>
31+
struct StructuredOpInterface
32+
: public RuntimeVerifiableOpInterface::ExternalModel<
33+
StructuredOpInterface<T>, T> {
34+
/// Emit the runtime checks for `op` through `builder`, using `loc` for every
/// created op. For each dimension of each operand, two `cf.assert` ops are
/// requested: one checking that the smallest accessed index is non-negative
/// and one checking that the inferred extent is compatible with the
/// operand's actual dimension size.
///
/// \param op      The operation to verify; must be a `LinalgOp` (enforced by
///                `llvm::cast`).
/// \param builder Builder used to create all verification IR.
/// \param loc     Location attached to all created ops.
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35+
Location loc) const {
36+
auto linalgOp = llvm::cast<LinalgOp>(op);
37+
38+
// Split the loop ranges into per-loop components. The first two are used as
// the first index and (after the adjustment below) the last index of each
// loop; the third component (strides) is ignored.
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
39+
auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
40+
41+
// Index constants shared by all checks emitted below.
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
42+
auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
43+
44+
// Subtract one from the loop ends before composing with the indexing map
// so that `ends` holds the last iteration index of each loop rather than a
// one-past-the-end value.
// NOTE(review): for a loop whose size is 0 this yields -1, so the
// non-negativity assert below would presumably fire (e.g. with an identity
// indexing map) -- confirm whether zero-sized operands must be supported.
45+
transform(ends, ends.begin(), [&](OpFoldResult end) {
46+
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
47+
return builder.createOrFold<index::SubOp>(loc, endValue, one);
48+
});
49+
50+
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
51+
// Apply this operand's indexing map to the first and to the last loop
// iteration to obtain, per operand dimension, the two extreme indices.
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
52+
auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
53+
builder, loc, indexingMap, starts);
54+
auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
55+
builder, loc, indexingMap, ends);
56+
57+
// One pair of asserts per operand dimension; a rank-0 operand produces no
// iterations and therefore no checks.
for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
58+
auto startIndex =
59+
getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
60+
auto endIndex =
61+
getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
62+
63+
// Generate:
64+
// minIndex = min(startIndex, endIndex)
65+
// assert(minIndex >= 0)
66+
// To ensure we do not generate a negative index. We take the minimum of
67+
// the start and end indices in order to handle reverse loops such as
68+
// `affine_map<(i) -> (3 - i)>`
69+
auto min =
70+
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
71+
auto cmpOp = builder.createOrFold<index::CmpOp>(
72+
loc, index::IndexCmpPredicate::SGE, min, zero);
73+
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
74+
linalgOp, "unexpected negative result on dimension #" +
75+
std::to_string(dim) + " of input/output operand #" +
76+
std::to_string(opOperand.getOperandNumber()));
77+
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
78+
79+
// Generate:
80+
// inferredDimSize = max(startIndex, endIndex) + 1
81+
// actualDimSize = dim(operand)
82+
// assert(inferredDimSize <= actualDimSize)
// (the comparison is `==` instead of `<=` when the indexing map result for
// this dimension is a plain dimension expression; see `predicate` below)
83+
// To ensure that we do not index past the bounds of the operands.
84+
auto max =
85+
builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
86+
87+
auto inferredDimSize =
88+
builder.createOrFold<index::AddOp>(loc, max, one);
89+
90+
auto actualDimSize =
91+
createOrFoldDimOp(builder, loc, opOperand.get(), dim);
92+
93+
// Similar to the verifier, when the affine expression in the indexing
94+
// map is complicated, we just check that the inferred dimension sizes
95+
// are in the boundary of the operands' size. Being more precise than
96+
// that is difficult.
97+
auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
98+
? index::IndexCmpPredicate::EQ
99+
: index::IndexCmpPredicate::SLE;
100+
101+
// `cmpOp` and `msg` are reused for the second assert of this dimension.
cmpOp = builder.createOrFold<index::CmpOp>(
102+
loc, predicate, inferredDimSize, actualDimSize);
103+
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
104+
linalgOp, "dimension #" + std::to_string(dim) +
105+
" of input/output operand #" +
106+
std::to_string(opOperand.getOperandNumber()) +
107+
" is incompatible with inferred dimension size");
108+
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
109+
}
110+
}
111+
}
112+
};
113+
114+
/// Attach the `StructuredOpInterface<Op>` external model to every op type
/// `Op` in the pack `OpTs`, within the context `ctx`.
template <typename... OpTs>
115+
void attachInterface(MLIRContext *ctx) {
116+
// Fold expression: one `attachInterface` call per op type in the pack.
(OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
117+
}
118+
} // namespace
119+
} // namespace linalg
120+
} // namespace mlir
121+
122+
/// Add to `registry` an extension, keyed on `LinalgDialect`, that attaches the
/// runtime-verification external model to the ops listed by `GET_OP_LIST` in
/// `LinalgStructuredOps.cpp.inc` and loads the dialects named below.
///
/// \param registry The dialect registry that receives the extension.
void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
123+
DialectRegistry &registry) {
124+
registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
125+
// The op list is spliced in as the template argument pack via the
// generated include.
attachInterface<
126+
#define GET_OP_LIST
127+
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
128+
>(ctx);
129+
130+
// Load additional dialects of which ops may get created.
// NOTE(review): memref::MemRefDialect is not listed here; presumably it is
// already loaded by other means when memref operands are verified -- confirm.
131+
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
132+
cf::ControlFlowDialect, index::IndexDialect,
133+
tensor::TensorDialect>();
134+
});
135+
}

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,6 @@
2020

2121
using namespace mlir;
2222

23-
/// Generate an error message string for the given op and the specified error.
24-
static std::string generateErrorMessage(Operation *op, const std::string &msg) {
25-
std::string buffer;
26-
llvm::raw_string_ostream stream(buffer);
27-
OpPrintingFlags flags;
28-
// We may generate a lot of error messages and so we need to ensure the
29-
// printing is fast.
30-
flags.elideLargeElementsAttrs();
31-
flags.printGenericOpForm();
32-
flags.skipRegions();
33-
flags.useLocalScope();
34-
stream << "ERROR: Runtime op verification failed\n";
35-
op->print(stream, flags);
36-
stream << "\n^ " << msg;
37-
stream << "\nLocation: ";
38-
op->getLoc().print(stream);
39-
return stream.str();
40-
}
41-
4223
namespace mlir {
4324
namespace memref {
4425
namespace {
@@ -62,8 +43,10 @@ struct CastOpInterface
6243
builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
6344
Value isSameRank = builder.create<arith::CmpIOp>(
6445
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
65-
builder.create<cf::AssertOp>(loc, isSameRank,
66-
generateErrorMessage(op, "rank mismatch"));
46+
builder.create<cf::AssertOp>(
47+
loc, isSameRank,
48+
RuntimeVerifiableOpInterface::generateErrorMessage(op,
49+
"rank mismatch"));
6750
}
6851

6952
// Get source offset and strides. We do not have an op to get offsets and
@@ -101,8 +84,8 @@ struct CastOpInterface
10184
loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
10285
builder.create<cf::AssertOp>(
10386
loc, isSameSz,
104-
generateErrorMessage(op, "size mismatch of dim " +
105-
std::to_string(it.index())));
87+
RuntimeVerifiableOpInterface::generateErrorMessage(
88+
op, "size mismatch of dim " + std::to_string(it.index())));
10689
}
10790

10891
// Get result offset and strides.
@@ -119,8 +102,10 @@ struct CastOpInterface
119102
builder.create<arith::ConstantIndexOp>(loc, resultOffset);
120103
Value isSameOffset = builder.create<arith::CmpIOp>(
121104
loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
122-
builder.create<cf::AssertOp>(loc, isSameOffset,
123-
generateErrorMessage(op, "offset mismatch"));
105+
builder.create<cf::AssertOp>(
106+
loc, isSameOffset,
107+
RuntimeVerifiableOpInterface::generateErrorMessage(
108+
op, "offset mismatch"));
124109
}
125110

126111
// Check strides.
@@ -137,8 +122,8 @@ struct CastOpInterface
137122
loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
138123
builder.create<cf::AssertOp>(
139124
loc, isSameStride,
140-
generateErrorMessage(op, "stride mismatch of dim " +
141-
std::to_string(it.index())));
125+
RuntimeVerifiableOpInterface::generateErrorMessage(
126+
op, "stride mismatch of dim " + std::to_string(it.index())));
142127
}
143128
}
144129
};
@@ -178,7 +163,9 @@ struct LoadStoreOpInterface
178163
: andOp;
179164
}
180165
builder.create<cf::AssertOp>(
181-
loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
166+
loc, assertCond,
167+
RuntimeVerifiableOpInterface::generateErrorMessage(
168+
op, "out-of-bounds access"));
182169
}
183170
};
184171

@@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface
248235

249236
builder.create<cf::AssertOp>(
250237
loc, assertCond,
251-
generateErrorMessage(
238+
RuntimeVerifiableOpInterface::generateErrorMessage(
252239
op,
253240
"result of reinterpret_cast is out-of-bounds of the base memref"));
254241
}
@@ -293,8 +280,8 @@ struct SubViewOpInterface
293280

294281
builder.create<cf::AssertOp>(
295282
loc, assertCond,
296-
generateErrorMessage(op,
297-
"subview is out-of-bounds of the base memref"));
283+
RuntimeVerifiableOpInterface::generateErrorMessage(
284+
op, "subview is out-of-bounds of the base memref"));
298285
}
299286
};
300287

@@ -334,8 +321,9 @@ struct ExpandShapeOpInterface
334321
builder.create<arith::ConstantIndexOp>(loc, 0));
335322
builder.create<cf::AssertOp>(
336323
loc, isModZero,
337-
generateErrorMessage(op, "static result dims in reassoc group do not "
338-
"divide src dim evenly"));
324+
RuntimeVerifiableOpInterface::generateErrorMessage(
325+
op, "static result dims in reassoc group do not "
326+
"divide src dim evenly"));
339327
}
340328
}
341329
};

mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,28 @@
1111
namespace mlir {
1212
class Location;
1313
class OpBuilder;
14+
15+
/// Generate an error message string for the given op and the specified error.
///
/// The returned text consists of, in order: the line
/// "ERROR: Runtime op verification failed", the printed form of `op`, a line
/// starting with "^ " followed by `msg`, and a line starting with
/// "Location: " followed by the printed location of `op`.
///
/// \param op  The operation whose verification failure is being reported.
/// \param msg The specific failure description to embed in the message.
/// \return The assembled message.
16+
std::string
17+
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
18+
const std::string &msg) {
19+
// The stream appends into `buffer`.
std::string buffer;
20+
llvm::raw_string_ostream stream(buffer);
21+
OpPrintingFlags flags;
22+
// We may generate a lot of error messages and so we need to ensure the
23+
// printing is fast.
24+
flags.elideLargeElementsAttrs();
25+
flags.printGenericOpForm();
26+
flags.skipRegions();
27+
flags.useLocalScope();
28+
stream << "ERROR: Runtime op verification failed\n";
29+
op->print(stream, flags);
30+
stream << "\n^ " << msg;
31+
stream << "\nLocation: ";
32+
op->getLoc().print(stream);
33+
return stream.str();
34+
}
35+
1436
} // namespace mlir
1537

1638
/// Include the definitions of the interface.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
2+
3+
// Most of the tests for linalg runtime-verification are implemented as integration tests.
4+
5+
#identity = affine_map<(d0) -> (d0)>
6+
7+
// CHECK-LABEL: @static_dims
8+
// All operands are statically shaped tensor<5xf32> with identity indexing
// maps; the expectations below require an assert whose condition is a
// constant `true`.
func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
9+
// CHECK: %[[TRUE:.*]] = index.bool.constant true
10+
// CHECK: cf.assert %[[TRUE]]
11+
%result = tensor.empty() : tensor<5xf32>
12+
%0 = linalg.generic {
13+
indexing_maps = [#identity, #identity, #identity],
14+
iterator_types = ["parallel"]
15+
} ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
16+
outs(%result : tensor<5xf32>) {
17+
^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
18+
%tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
19+
linalg.yield %tmp1 : f32
20+
} -> tensor<5xf32>
21+
return %0 : tensor<5xf32>
22+
}
23+
24+
// -----
25+
26+
#map = affine_map<() -> ()>
27+
28+
// CHECK-LABEL: @scalars
29+
// Rank-0 operands with empty indexing maps and no iterators: there is no
// operand dimension to compare against, so no assert is expected.
func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
30+
// No runtime checks are required if the operands are all scalars
31+
// CHECK-NOT: cf.assert
32+
%result = tensor.empty() : tensor<f32>
33+
%0 = linalg.generic {
34+
indexing_maps = [#map, #map, #map],
35+
iterator_types = []
36+
} ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
37+
outs(%result : tensor<f32>) {
38+
^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
39+
%tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
40+
linalg.yield %tmp1 : f32
41+
} -> tensor<f32>
42+
return %0 : tensor<f32>
43+
}

0 commit comments

Comments
 (0)