Skip to content

Fix some AffineOps to properly declare their inherent affinemap Attribute #66050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class AffineDmaStartOp
/// Returns the affine map used to access the source memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
return cast<AffineMapAttr>((*this)->getAttr(getSrcMapAttrStrName()));
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
}

/// Returns the source memref affine map indices for this DMA operation.
Expand Down Expand Up @@ -156,7 +156,7 @@ class AffineDmaStartOp
/// Returns the affine map used to access the destination memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
return cast<AffineMapAttr>((*this)->getAttr(getDstMapAttrStrName()));
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
}

/// Returns the destination memref indices for this DMA operation.
Expand Down Expand Up @@ -185,7 +185,7 @@ class AffineDmaStartOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>((*this)->getAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
}

/// Returns the tag memref indices for this DMA operation.
Expand Down Expand Up @@ -307,7 +307,7 @@ class AffineDmaWaitOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>((*this)->getAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
}

/// Returns the tag memref index for this DMA operation.
Expand Down
35 changes: 20 additions & 15 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,18 @@ def AffineForOp : Affine_Op<"for",

/// Returns loop step.
int64_t getStep() {
return ::llvm::cast<IntegerAttr>((*this)->getAttr(getStepAttrStrName())).getInt();
return ::llvm::cast<IntegerAttr>(*(*this)->getInherentAttr(getStepAttrStrName())).getInt();
}

/// Returns affine map for the lower bound.
AffineMap getLowerBoundMap() { return getLowerBoundMapAttr().getValue(); }
AffineMapAttr getLowerBoundMapAttr() {
return ::llvm::cast<AffineMapAttr>((*this)->getAttr(getLowerBoundAttrStrName()));
return ::llvm::cast<AffineMapAttr>(*(*this)->getInherentAttr(getLowerBoundAttrStrName()));
}
/// Returns affine map for the upper bound. The upper bound is exclusive.
AffineMap getUpperBoundMap() { return getUpperBoundMapAttr().getValue(); }
AffineMapAttr getUpperBoundMapAttr() {
return ::llvm::cast<AffineMapAttr>((*this)->getAttr(getUpperBoundAttrStrName()));
return ::llvm::cast<AffineMapAttr>(*(*this)->getInherentAttr(getUpperBoundAttrStrName()));
}

/// Set lower bound. The new bound must have the same number of operands as
Expand Down Expand Up @@ -497,7 +497,8 @@ class AffineLoadOpBase<string mnemonic, list<Trait> traits = []> :
MemRefsNormalizable])> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
Variadic<Index>:$indices,
AffineMapAttr:$map);

code extraClassDeclarationBase = [{
/// Returns the operand index of the memref.
Expand All @@ -507,7 +508,7 @@ class AffineLoadOpBase<string mnemonic, list<Trait> traits = []> :

/// Returns the affine map used to index the memref for this operation.
AffineMapAttr getAffineMapAttr() {
return ::llvm::cast<AffineMapAttr>((*this)->getAttr(getMapAttrStrName()));
return getProperties().map;
}

static StringRef getMapAttrStrName() { return "map"; }
Expand Down Expand Up @@ -801,7 +802,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch",
BoolAttr:$isWrite,
ConfinedAttr<I32Attr, [IntMinValue<0>,
IntMaxValue<3>]>:$localityHint,
BoolAttr:$isDataCache);
BoolAttr:$isDataCache,
AffineMapAttr:$map);

let builders = [
OpBuilder<(ins "Value":$memref, "AffineMap":$map,
Expand All @@ -814,11 +816,12 @@ def AffinePrefetchOp : Affine_Op<"prefetch",
auto isWriteAttr = $_builder.getBoolAttr(isWrite);
auto isDataCacheAttr = $_builder.getBoolAttr(isDataCache);
$_state.addOperands(memref);
$_state.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
$_state.addOperands(mapOperands);
$_state.addAttribute(getLocalityHintAttrStrName(), localityHintAttr);
$_state.addAttribute(getIsWriteAttrStrName(), isWriteAttr);
$_state.addAttribute(getIsDataCacheAttrStrName(), isDataCacheAttr);
Properties &prop = $_state.getOrAddProperties<Properties>();
prop.map = AffineMapAttr::get(map);
prop.localityHint = localityHintAttr;
prop.isWrite = isWriteAttr;
prop.isDataCache = isDataCacheAttr;
}]>];

let extraClassDeclaration = [{
Expand All @@ -829,10 +832,10 @@ def AffinePrefetchOp : Affine_Op<"prefetch",
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return ::llvm::cast<AffineMapAttr>((*this)->getAttr(getMapAttrStrName()));
return getProperties().map;
}

/// Impelements the AffineMapAccessInterface.
/// Implements the AffineMapAccessInterface.
/// Returns the AffineMapAttr associated with 'memref'.
NamedAttribute getAffineMapAttrForMemRef(Value mref) {
assert(mref == getMemref() &&
Expand Down Expand Up @@ -874,7 +877,7 @@ class AffineStoreOpBase<string mnemonic, list<Trait> traits = []> :

/// Returns the affine map used to index the memref for this operation.
AffineMapAttr getAffineMapAttr() {
return ::llvm::cast<AffineMapAttr>((*this)->getAttr(getMapAttrStrName()));
return getProperties().map;
}

static StringRef getMapAttrStrName() { return "map"; }
Expand Down Expand Up @@ -912,7 +915,8 @@ def AffineStoreOp : AffineStoreOpBase<"store"> {
let arguments = (ins AnyType:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices);
Variadic<Index>:$indices,
AffineMapAttr:$map);

let skipDefaultBuilders = 1;
let builders = [
Expand Down Expand Up @@ -1065,7 +1069,8 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
let arguments = (ins AnyVector:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices);
Variadic<Index>:$indices,
AffineMapAttr:$map);

let skipDefaultBuilders = 1;
let builders = [
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/SymbolInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
// If this is an optional symbol, bail out early if possible.
auto concreteOp = cast<ConcreteOp>($_op);
if (concreteOp.isOptionalSymbol()) {
if(!concreteOp->getAttr(::mlir::SymbolTable::getSymbolAttrName()))
if(!concreteOp->getInherentAttr(::mlir::SymbolTable::getSymbolAttrName()).value_or(Attribute{}))
return success();
}
if (::mlir::failed(::mlir::detail::verifySymbol($_op)))
Expand Down
8 changes: 1 addition & 7 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3135,17 +3135,11 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
if (mapAttr) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");
} else {
if (memrefType.getRank() != numIndexOperands)
return op->emitOpError(
"expects the number of subscripts to be equal to memref rank");
}

Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
Expand Down Expand Up @@ -3224,7 +3218,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(valueToStore);
result.addOperands(memref);
result.addOperands(mapOperands);
result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
}

// Use identity map.
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func.func @affine_store_missing_l_square(%C: memref<4096x4096xf32>) {
func.func @affine_store_wrong_value_type(%C: memref<f32>) {
%c0 = arith.constant 0 : i32
// expected-error@+1 {{value to store must have the same type as memref element type}}
"affine.store"(%c0, %C) : (i32, memref<f32>) -> ()
"affine.store"(%c0, %C) <{map = affine_map<(i) -> (i)>}> : (i32, memref<f32>) -> ()
return
}

Expand Down
29 changes: 0 additions & 29 deletions mlir/test/Dialect/Affine/load-store-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

func.func @load_too_many_subscripts(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
"affine.load"(%arg0, %arg1, %arg2, %arg3) : (memref<?x?xf32>, index, index, index) -> f32
}

// -----

func.func @load_too_many_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// expected-error@+1 {{op expects as many subscripts as affine map inputs}}
"affine.load"(%arg0, %arg1, %arg2, %arg3)
Expand All @@ -15,13 +8,6 @@ func.func @load_too_many_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %a

// -----

func.func @load_too_few_subscripts(%arg0: memref<?x?xf32>, %arg1: index) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
"affine.load"(%arg0, %arg1) : (memref<?x?xf32>, index) -> f32
}

// -----

func.func @load_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index) {
// expected-error@+1 {{op expects as many subscripts as affine map inputs}}
"affine.load"(%arg0, %arg1)
Expand All @@ -30,14 +16,6 @@ func.func @load_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index) {

// -----

func.func @store_too_many_subscripts(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index,
%arg3: index, %val: f32) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
"affine.store"(%val, %arg0, %arg1, %arg2, %arg3) : (f32, memref<?x?xf32>, index, index, index) -> ()
}

// -----

func.func @store_too_many_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index,
%arg3: index, %val: f32) {
// expected-error@+1 {{op expects as many subscripts as affine map inputs}}
Expand All @@ -47,13 +25,6 @@ func.func @store_too_many_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %

// -----

func.func @store_too_few_subscripts(%arg0: memref<?x?xf32>, %arg1: index, %val: f32) {
// expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
"affine.store"(%val, %arg0, %arg1) : (f32, memref<?x?xf32>, index) -> ()
}

// -----

func.func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %val: f32) {
// expected-error@+1 {{op expects as many subscripts as affine map inputs}}
"affine.store"(%val, %arg0, %arg1)
Expand Down