Skip to content

Commit 085075a

Browse files
[mlir][transform] Check for invalidated iterators on payload values (#66472)
Same as #66369 but for payload values. (#66369 added checks only for payload operations.) It was necessary to change the signature of `getPayloadValues` to return an iterator. This is now similar to payload operations. Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.
1 parent 702608f commit 085075a

File tree

6 files changed

+85
-43
lines changed

6 files changed

+85
-43
lines changed

mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
9595
TransformResults &results,
9696
TransformState &state) {
9797
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
98-
ValueRange payload = state.getPayloadValues(operandHandle);
99-
if (payload.size() != 1) {
98+
auto payload = state.getPayloadValues(operandHandle);
99+
if (!llvm::hasSingleElement(payload)) {
100100
return emitDefiniteFailure(this->getOperation()->getLoc())
101101
<< "SingleValueMatchOpTrait requires the value handle to point to "
102102
"a single payload value";
103103
}
104104

105105
return cast<OpTy>(this->getOperation())
106-
.matchValue(payload[0], results, state);
106+
.matchValue(*payload.begin(), results, state);
107107
}
108108

109109
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class TransformState {
170170
/// should be emitted when the value is used.
171171
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
172172

173-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
173+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
174174
/// Debug only: A timestamp is associated with each transform IR value, so
175175
/// that invalid iterator usage can be detected more reliably.
176176
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
@@ -185,7 +185,7 @@ class TransformState {
185185
ValueMapping values;
186186
ValueMapping reverseValues;
187187

188-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
188+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
189189
TransformIRTimestampMapping timestamps;
190190
void incrementTimestamp(Value value) { ++timestamps[value]; }
191191
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -220,7 +220,7 @@ class TransformState {
220220
auto getPayloadOps(Value value) const {
221221
ArrayRef<Operation *> view = getPayloadOpsView(value);
222222

223-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
223+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
224224
// Memorize the current timestamp and make sure that it has not changed
225225
// when incrementing or dereferencing the iterator returned by this
226226
// function. The timestamp is incremented when the "direct" mapping is
@@ -231,7 +231,7 @@ class TransformState {
231231
// When ops are replaced/erased, they are replaced with nullptr (until
232232
// the data structure is compacted). Do not enumerate these ops.
233233
return llvm::make_filter_range(view, [=](Operation *op) {
234-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
234+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
235235
bool sameTimestamp =
236236
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
237237
assert(sameTimestamp && "iterator was invalidated during iteration");
@@ -244,9 +244,29 @@ class TransformState {
244244
/// corresponds to.
245245
ArrayRef<Attribute> getParams(Value value) const;
246246

247-
/// Returns the list of payload IR values that the given transform IR value
248-
/// corresponds to.
249-
ArrayRef<Value> getPayloadValues(Value handleValue) const;
247+
/// Returns an iterator that enumerates all payload IR values that the given
248+
/// transform IR value corresponds to.
249+
auto getPayloadValues(Value handleValue) const {
250+
ArrayRef<Value> view = getPayloadValuesView(handleValue);
251+
252+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
253+
// Memorize the current timestamp and make sure that it has not changed
254+
// when incrementing or dereferencing the iterator returned by this
255+
// function. The timestamp is incremented when the "values" mapping is
256+
// resized; this would invalidate the iterator returned by this function.
257+
int64_t currentTimestamp =
258+
getMapping(handleValue).timestamps.lookup(handleValue);
259+
return llvm::make_filter_range(view, [=](Value v) {
260+
bool sameTimestamp =
261+
currentTimestamp ==
262+
this->getMapping(handleValue).timestamps.lookup(handleValue);
263+
assert(sameTimestamp && "iterator was invalidated during iteration");
264+
return true;
265+
});
266+
#else
267+
return llvm::make_range(view.begin(), view.end());
268+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
269+
}
250270

251271
/// Populates `handles` with all handles pointing to the given Payload IR op.
252272
/// Returns success if such handles exist, failure otherwise.
@@ -501,12 +521,15 @@ class TransformState {
501521
LogicalResult updateStateFromResults(const TransformResults &results,
502522
ResultRange opResults);
503523

504-
/// Returns a list of all ops that the given transform IR value corresponds to
505-
/// at the time when this function is called. In case an op was erased, the
506-
/// returned list contains nullptr. This function is helpful for
507-
/// transformations that apply to a particular handle.
524+
/// Returns a list of all ops that the given transform IR value corresponds
525+
/// to. In case an op was erased, the returned list contains nullptr. This
526+
/// function is helpful for transformations that apply to a particular handle.
508527
ArrayRef<Operation *> getPayloadOpsView(Value value) const;
509528

529+
/// Returns a list of payload IR values that the given transform IR value
530+
/// corresponds to.
531+
ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
532+
510533
/// Sets the payload IR ops associated with the given transform IR value
511534
/// (handle). A payload op may be associated multiple handles as long as
512535
/// at most one of them gets consumed by further transformations.
@@ -806,7 +829,27 @@ class TransformResults {
806829
/// set by the transformation exactly once in case of transformation
807830
/// succeeding. The value must have a type implementing
808831
/// TransformValueHandleTypeInterface.
809-
void setValues(OpResult handle, ValueRange values);
832+
template <typename Range>
833+
void setValues(OpResult handle, Range &&values) {
834+
int64_t position = handle.getResultNumber();
835+
assert(position < static_cast<int64_t>(this->values.size()) &&
836+
"setting values for a non-existent handle");
837+
assert(this->values[position].data() == nullptr && "values already set");
838+
assert(operations[position].data() == nullptr &&
839+
"another kind of results already set");
840+
assert(params[position].data() == nullptr &&
841+
"another kind of results already set");
842+
this->values.replace(position, std::forward<Range>(values));
843+
}
844+
845+
/// Indicates that the result of the transform IR op at the given position
846+
/// corresponds to the given range of payload IR values. Each result must be
847+
/// set by the transformation exactly once in case of transformation
848+
/// succeeding. The value must have a type implementing
849+
/// TransformValueHandleTypeInterface.
850+
void setValues(OpResult handle, std::initializer_list<Value> values) {
851+
setValues(handle, ArrayRef<Value>(values));
852+
}
810853

811854
/// Indicates that the result of the transform IR op at the given position
812855
/// corresponds to the given range of mapped values. All mapped values are

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
728728

729729
Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
730730
if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
731-
results.setValues(cast<OpResult>(getResult()), result);
731+
results.setValues(cast<OpResult>(getResult()), {result});
732732
return DiagnosedSilenceableFailure::success();
733733
}
734734

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
7575
}
7676

7777
ArrayRef<Value>
78-
transform::TransformState::getPayloadValues(Value handleValue) const {
78+
transform::TransformState::getPayloadValuesView(Value handleValue) const {
7979
const ValueMapping &mapping = getMapping(handleValue).values;
8080
auto iter = mapping.find(handleValue);
8181
assert(iter != mapping.end() && "cannot find mapping for value handle "
@@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
310310
for (Operation *op : mappings.direct[opHandle])
311311
dropMappingEntry(mappings.reverse, op, opHandle);
312312
mappings.direct.erase(opHandle);
313-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
313+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
314314
// Payload IR is removed from the mapping. This invalidates the respective
315315
// iterators.
316316
mappings.incrementTimestamp(opHandle);
@@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
322322
for (Value resultHandle : resultHandles) {
323323
Mappings &localMappings = getMapping(resultHandle);
324324
dropMappingEntry(localMappings.values, resultHandle, opResult);
325+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
326+
// Payload IR is removed from the mapping. This invalidates the respective
327+
// iterators.
328+
mappings.incrementTimestamp(resultHandle);
329+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
325330
dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
326331
}
327332
}
@@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
333338
for (Value payloadValue : mappings.reverseValues[valueHandle])
334339
dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
335340
mappings.values.erase(valueHandle);
341+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
342+
// Payload IR is removed from the mapping. This invalidates the respective
343+
// iterators.
344+
mappings.incrementTimestamp(valueHandle);
345+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
336346

337347
for (Operation *payloadOp : payloadOperations) {
338348
SmallVector<Value> opHandles;
@@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
342352
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
343353
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
344354

345-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
355+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
346356
// Payload IR is removed from the mapping. This invalidates the respective
347357
// iterators.
348358
localMappings.incrementTimestamp(opHandle);
@@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
439449
// between the handles and the IR objects
440450
if (!replacement) {
441451
dropMappingEntry(mappings.values, handle, value);
452+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
453+
// Payload IR is removed from the mapping. This invalidates the respective
454+
// iterators.
455+
mappings.incrementTimestamp(handle);
456+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
442457
} else {
443458
auto it = mappings.values.find(handle);
444459
if (it == mappings.values.end())
@@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
647662
OpOperand &valueHandle,
648663
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
649664
// Invalidate other handles to the same value.
650-
for (Value payloadValue : getPayloadValues(valueHandle.get())) {
665+
for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
651666
SmallVector<Value> otherValueHandles;
652667
(void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
653668
for (Value otherHandle : otherValueHandles) {
@@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
785800
void transform::TransformState::compactOpHandles() {
786801
for (Value handle : opHandlesToCompact) {
787802
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
788-
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
803+
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
789804
if (llvm::find(mappings.direct[handle], nullptr) !=
790805
mappings.direct[handle].end())
791806
// Payload IR is removed from the mapping. This invalidates the respective
@@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
846861
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
847862
DiagnosedSilenceableFailure check =
848863
checkRepeatedConsumptionInOperand<Value>(
849-
getPayloadValues(operand.get()), transform,
864+
getPayloadValuesView(operand.get()), transform,
850865
operand.getOperandNumber());
851866
if (!check.succeeded()) {
852867
FULL_LDBG("----FAILED\n");
@@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
912927
continue;
913928
}
914929
if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
915-
for (Value payloadValue : getPayloadValues(operand)) {
930+
for (Value payloadValue : getPayloadValuesView(operand)) {
916931
if (llvm::isa<OpResult>(payloadValue)) {
917932
origAssociatedOps.push_back(payloadValue.getDefiningOp());
918933
continue;
@@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
11701185
this->params.replace(position, params);
11711186
}
11721187

1173-
void transform::TransformResults::setValues(OpResult handle,
1174-
ValueRange values) {
1175-
int64_t position = handle.getResultNumber();
1176-
assert(position < static_cast<int64_t>(this->values.size()) &&
1177-
"setting values for a non-existent handle");
1178-
assert(this->values[position].data() == nullptr && "values already set");
1179-
assert(operations[position].data() == nullptr &&
1180-
"another kind of results already set");
1181-
assert(params[position].data() == nullptr &&
1182-
"another kind of results already set");
1183-
this->values.replace(position, values);
1184-
}
1185-
11861188
void transform::TransformResults::setMappedValues(
11871189
OpResult handle, ArrayRef<MappedValue> values) {
11881190
DiagnosedSilenceableFailure diag = dispatchMappedValues(

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,9 +1379,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
13791379
transform::TransformResults &results,
13801380
transform::TransformState &state) {
13811381
SmallVector<Attribute> params;
1382-
ArrayRef<Value> values = state.getPayloadValues(getValue());
1383-
params.reserve(values.size());
1384-
for (Value value : values) {
1382+
for (Value value : state.getPayloadValues(getValue())) {
13851383
Type type = value.getType();
13861384
if (getElemental()) {
13871385
if (auto shaped = dyn_cast<ShapedType>(type)) {

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
136136
mlir::test::TestProduceValueHandleToSelfOperand::apply(
137137
transform::TransformRewriter &rewriter,
138138
transform::TransformResults &results, transform::TransformState &state) {
139-
results.setValues(llvm::cast<OpResult>(getOut()), getIn());
139+
results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
140140
return DiagnosedSilenceableFailure::success();
141141
}
142142

@@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
265265
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
266266
transform::TransformRewriter &rewriter,
267267
transform::TransformResults &results, transform::TransformState &state) {
268-
ArrayRef<Value> values = state.getPayloadValues(getIn());
269-
for (Value value : values) {
268+
for (Value value : state.getPayloadValues(getIn())) {
270269
std::string note;
271270
llvm::raw_string_ostream os(note);
272271
if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
@@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
712711
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
713712
transform::TransformRewriter &rewriter,
714713
transform::TransformResults &results, transform::TransformState &state) {
715-
results.setValues(llvm::cast<OpResult>(getOut()), Value());
714+
results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
716715
return DiagnosedSilenceableFailure::success();
717716
}
718717

0 commit comments

Comments
 (0)