Skip to content

[mlir][transform] Check for invalidated iterators on payload values #66472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
ValueRange payload = state.getPayloadValues(operandHandle);
if (payload.size() != 1) {
auto payload = state.getPayloadValues(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleValueMatchOpTrait requires the value handle to point to "
"a single payload value";
}

return cast<OpTy>(this->getOperation())
.matchValue(payload[0], results, state);
.matchValue(*payload.begin(), results, state);
}

void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
Expand Down
70 changes: 57 additions & 13 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class TransformState {
/// should be emitted when the value is used.
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Debug only: A timestamp is associated with each transform IR value, so
/// that invalid iterator usage can be detected more reliably.
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
Expand All @@ -185,7 +185,7 @@ class TransformState {
ValueMapping values;
ValueMapping reverseValues;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
TransformIRTimestampMapping timestamps;
void incrementTimestamp(Value value) { ++timestamps[value]; }
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
Expand Down Expand Up @@ -220,7 +220,7 @@ class TransformState {
auto getPayloadOps(Value value) const {
ArrayRef<Operation *> view = getPayloadOpsView(value);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "direct" mapping is
Expand All @@ -231,7 +231,7 @@ class TransformState {
// When ops are replaced/erased, they are replaced with nullptr (until
// the data structure is compacted). Do not enumerate these ops.
return llvm::make_filter_range(view, [=](Operation *op) {
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
bool sameTimestamp =
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
assert(sameTimestamp && "iterator was invalidated during iteration");
Expand All @@ -244,9 +244,29 @@ class TransformState {
/// corresponds to.
ArrayRef<Attribute> getParams(Value value) const;

/// Returns the list of payload IR values that the given transform IR value
/// corresponds to.
ArrayRef<Value> getPayloadValues(Value handleValue) const;
/// Returns an iterator that enumerates all payload IR values that the given
/// transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const {
ArrayRef<Value> view = getPayloadValuesView(handleValue);

#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "values" mapping is
// resized; this would invalidate the iterator returned by this function.
int64_t currentTimestamp =
getMapping(handleValue).timestamps.lookup(handleValue);
return llvm::make_filter_range(view, [=](Value v) {
bool sameTimestamp =
currentTimestamp ==
this->getMapping(handleValue).timestamps.lookup(handleValue);
assert(sameTimestamp && "iterator was invalidated during iteration");
return true;
});
#else
return llvm::make_range(view.begin(), view.end());
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

/// Populates `handles` with all handles pointing to the given Payload IR op.
/// Returns success if such handles exist, failure otherwise.
Expand Down Expand Up @@ -501,12 +521,15 @@ class TransformState {
LogicalResult updateStateFromResults(const TransformResults &results,
ResultRange opResults);

/// Returns a list of all ops that the given transform IR value corresponds to
/// at the time when this function is called. In case an op was erased, the
/// returned list contains nullptr. This function is helpful for
/// transformations that apply to a particular handle.
/// Returns a list of all ops that the given transform IR value corresponds
/// to. In case an op was erased, the returned list contains nullptr. This
/// function is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOpsView(Value value) const;

/// Returns a list of payload IR values that the given transform IR value
/// corresponds to.
ArrayRef<Value> getPayloadValuesView(Value handleValue) const;

/// Sets the payload IR ops associated with the given transform IR value
/// (handle). A payload op may be associated multiple handles as long as
/// at most one of them gets consumed by further transformations.
Expand Down Expand Up @@ -774,7 +797,8 @@ class TransformResults {
/// corresponds to the given list of payload IR ops. Each result must be set
/// by the transformation exactly once in case of transformation succeeding.
/// The value must have a type implementing TransformHandleTypeInterface.
template <typename Range> void set(OpResult value, Range &&ops) {
template <typename Range>
void set(OpResult value, Range &&ops) {
int64_t position = value.getResultNumber();
assert(position < static_cast<int64_t>(operations.size()) &&
"setting results for a non-existent handle");
Expand Down Expand Up @@ -805,7 +829,27 @@ class TransformResults {
/// set by the transformation exactly once in case of transformation
/// succeeding. The value must have a type implementing
/// TransformValueHandleTypeInterface.
void setValues(OpResult handle, ValueRange values);
template <typename Range>
void setValues(OpResult handle, Range &&values) {
int64_t position = handle.getResultNumber();
assert(position < static_cast<int64_t>(this->values.size()) &&
"setting values for a non-existent handle");
assert(this->values[position].data() == nullptr && "values already set");
assert(operations[position].data() == nullptr &&
"another kind of results already set");
assert(params[position].data() == nullptr &&
"another kind of results already set");
this->values.replace(position, std::forward<Range>(values));
}

/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given range of payload IR values. Each result must be
/// set by the transformation exactly once in case of transformation
/// succeeding. The value must have a type implementing
/// TransformValueHandleTypeInterface.
void setValues(OpResult handle, std::initializer_list<Value> values) {
setValues(handle, ArrayRef<Value>(values));
}

/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given range of mapped values. All mapped values are
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(

Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
results.setValues(cast<OpResult>(getResult()), result);
results.setValues(cast<OpResult>(getResult()), {result});
return DiagnosedSilenceableFailure::success();
}

Expand Down
42 changes: 22 additions & 20 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
}

ArrayRef<Value>
transform::TransformState::getPayloadValues(Value handleValue) const {
transform::TransformState::getPayloadValuesView(Value handleValue) const {
const ValueMapping &mapping = getMapping(handleValue).values;
auto iter = mapping.find(handleValue);
assert(iter != mapping.end() && "cannot find mapping for value handle "
Expand Down Expand Up @@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(opHandle);
Expand All @@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Value resultHandle : resultHandles) {
Mappings &localMappings = getMapping(resultHandle);
dropMappingEntry(localMappings.values, resultHandle, opResult);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(resultHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
}
}
Expand All @@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
for (Value payloadValue : mappings.reverseValues[valueHandle])
dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
mappings.values.erase(valueHandle);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(valueHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

for (Operation *payloadOp : payloadOperations) {
SmallVector<Value> opHandles;
Expand All @@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
localMappings.incrementTimestamp(opHandle);
Expand Down Expand Up @@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
// between the handles and the IR objects
if (!replacement) {
dropMappingEntry(mappings.values, handle, value);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
} else {
auto it = mappings.values.find(handle);
if (it == mappings.values.end())
Expand Down Expand Up @@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
OpOperand &valueHandle,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// Invalidate other handles to the same value.
for (Value payloadValue : getPayloadValues(valueHandle.get())) {
for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
SmallVector<Value> otherValueHandles;
(void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
for (Value otherHandle : otherValueHandles) {
Expand Down Expand Up @@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
if (llvm::find(mappings.direct[handle], nullptr) !=
mappings.direct[handle].end())
// Payload IR is removed from the mapping. This invalidates the respective
Expand Down Expand Up @@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
getPayloadValues(operand.get()), transform,
getPayloadValuesView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
Expand Down Expand Up @@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
continue;
}
if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
for (Value payloadValue : getPayloadValues(operand)) {
for (Value payloadValue : getPayloadValuesView(operand)) {
if (llvm::isa<OpResult>(payloadValue)) {
origAssociatedOps.push_back(payloadValue.getDefiningOp());
continue;
Expand Down Expand Up @@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
this->params.replace(position, params);
}

void transform::TransformResults::setValues(OpResult handle,
ValueRange values) {
int64_t position = handle.getResultNumber();
assert(position < static_cast<int64_t>(this->values.size()) &&
"setting values for a non-existent handle");
assert(this->values[position].data() == nullptr && "values already set");
assert(operations[position].data() == nullptr &&
"another kind of results already set");
assert(params[position].data() == nullptr &&
"another kind of results already set");
this->values.replace(position, values);
}

void transform::TransformResults::setMappedValues(
OpResult handle, ArrayRef<MappedValue> values) {
DiagnosedSilenceableFailure diag = dispatchMappedValues(
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1378,9 +1378,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Attribute> params;
ArrayRef<Value> values = state.getPayloadValues(getValue());
params.reserve(values.size());
for (Value value : values) {
for (Value value : state.getPayloadValues(getValue())) {
Type type = value.getType();
if (getElemental()) {
if (auto shaped = dyn_cast<ShapedType>(type)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), getIn());
results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
ArrayRef<Value> values = state.getPayloadValues(getIn());
for (Value value : values) {
for (Value value : state.getPayloadValues(getIn())) {
std::string note;
llvm::raw_string_ostream os(note);
if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
Expand Down Expand Up @@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), Value());
results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
return DiagnosedSilenceableFailure::success();
}

Expand Down