Skip to content

Commit b83074f

Browse files
committed
port fixes from local llvm
1 parent 103fa32 commit b83074f

File tree

8 files changed

+70
-36
lines changed

8 files changed

+70
-36
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ struct BufferizationOptions {
257257
/// Parameters: Value, memory space, bufferization options
258258
using UnknownTypeConverterFn = std::function<BaseMemRefType(
259259
Value, Attribute memorySpace, const BufferizationOptions &)>;
260+
// Produce a MemorySpace attribute from a tensor type.
261+
using GetMemorySpaceFn =
262+
std::function<std::optional<Attribute>(TensorType t)>;
260263

261264
BufferizationOptions();
262265

@@ -351,6 +354,16 @@ struct BufferizationOptions {
351354
/// used.
352355
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
353356

357+
// Used during type conversion to determine the memory space for memref based
358+
// on the original tensor type.
359+
GetMemorySpaceFn getMemorySpaceFn = nullptr;
360+
361+
std::optional<Attribute> getMemorySpace(TensorType t) const {
362+
if (getMemorySpaceFn)
363+
return getMemorySpaceFn(t);
364+
return defaultMemorySpace;
365+
}
366+
354367
/// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
355368
/// Should be used only with `testAnalysisOnly = true`.
356369
unsigned analysisFuzzerSeed = 0;

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ struct ConstantOpInterface
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2727
const BufferizationOptions &options) const {
2828
auto constantOp = cast<arith::ConstantOp>(op);
29+
auto type = constantOp.getType().dyn_cast<RankedTensorType>();
30+
31+
// Only ranked tensors are supported.
32+
if (!type)
33+
return failure();
2934

3035
Attribute memorySpace;
31-
if (options.defaultMemorySpace.has_value())
32-
memorySpace = *options.defaultMemorySpace;
36+
if (options.getMemorySpace(type))
37+
memorySpace = *options.getMemorySpace(type);
3338
else
3439
return constantOp->emitError("could not infer memory space");
3540

36-
// Only ranked tensors are supported.
37-
if (!isa<RankedTensorType>(constantOp.getType()))
38-
return failure();
39-
4041
// Only constants inside a module are supported.
4142
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
4243
if (!moduleOp)

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,11 +682,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
682682
return bufferizableOp.getBufferType(value, options, invocationStack);
683683

684684
// Op is not bufferizable.
685-
if (!options.defaultMemorySpace.has_value())
685+
auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
686+
if (!memSpace.has_value())
686687
return op->emitError("could not infer memory space");
687688

688-
return getMemRefType(value, options, /*layout=*/{},
689-
*options.defaultMemorySpace);
689+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
690690
}
691691

692692
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -943,11 +943,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
943943

944944
// If we do not know the memory space and there is no default memory space,
945945
// report a failure.
946-
if (!options.defaultMemorySpace.has_value())
946+
auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
947+
if (!memSpace.has_value())
947948
return op->emitError("could not infer memory space");
948949

949-
return getMemRefType(value, options, /*layout=*/{},
950-
*options.defaultMemorySpace);
950+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
951951
}
952952

953953
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
234234
if (failed(copyBufferType))
235235
return failure();
236236
memorySpace = copyBufferType->getMemorySpace();
237-
} else if (options.defaultMemorySpace.has_value()) {
238-
memorySpace = *options.defaultMemorySpace;
237+
} else if (auto x = options.getMemorySpace(getType()); x.has_value()) {
238+
memorySpace = *x;
239239
} else {
240240
return getOperation()->emitError("could not infer memory space");
241241
}

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
6666
assert(tensorType && "expected TensorType");
6767

6868
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
69-
tensorType, *options.defaultMemorySpace, funcOp, options);
69+
tensorType, *options.getMemorySpace(tensorType), funcOp, options);
7070

7171
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
7272
index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,7 @@ struct FuncOpInterface
443443
// Note: If `inferFunctionResultLayout = true`, cast are later folded
444444
// away.
445445
BaseMemRefType resultType = options.functionArgTypeConverterFn(
446-
tensorType, *options.defaultMemorySpace, funcOp, options);
446+
tensorType, *options.getMemorySpace(tensorType), funcOp, options);
447447
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
448448
loc, resultType, returnVal);
449449
returnValues.push_back(toMemrefOp);

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/IRMapping.h"
2222
#include "mlir/IR/Matchers.h"
2323
#include "mlir/IR/OpDefinition.h"
24+
#include "mlir/IR/TensorEncoding.h"
2425
#include "mlir/IR/TypeUtilities.h"
2526
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2627
#include "mlir/Interfaces/LoopLikeInterface.h"
@@ -1622,7 +1623,20 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
16221623
currentDim += dim;
16231624
}
16241625

1625-
return RankedTensorType::get(newShape, type.getElementType());
1626+
auto encoding = type.getEncoding();
1627+
if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
1628+
auto ignoreError = [&] {
1629+
auto emitter = mlir::emitError(UnknownLoc::get(type.getContext()));
1630+
emitter.abandon();
1631+
return emitter;
1632+
};
1633+
if (failed(
1634+
v.verifyEncoding(newShape, type.getElementType(), ignoreError))) {
1635+
// Strip the encoding if it is not valid for the new shape.
1636+
encoding = Attribute();
1637+
}
1638+
}
1639+
return RankedTensorType::get(newShape, type.getElementType(), encoding);
16261640
}
16271641

16281642
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -1902,7 +1916,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
19021916
assert(static_cast<int64_t>(staticSizes.size()) ==
19031917
sourceTensorType.getRank() &&
19041918
"unexpected staticSizes not equal to rank of source");
1905-
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
1919+
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
1920+
sourceTensorType.getEncoding());
19061921
}
19071922

19081923
RankedTensorType ExtractSliceOp::inferResultType(
@@ -1943,7 +1958,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
19431958
if (!dimsToProject.test(pos))
19441959
projectedShape.push_back(shape[pos]);
19451960
inferredType =
1946-
RankedTensorType::get(projectedShape, inferredType.getElementType());
1961+
RankedTensorType::get(projectedShape, inferredType.getElementType(),
1962+
inferredType.getEncoding());
19471963
}
19481964
return inferredType;
19491965
}
@@ -2663,8 +2679,8 @@ struct InsertSliceOpSourceCastInserter final
26632679
if (!hasValidSizesOffsets(newSrcShape))
26642680
return failure();
26652681

2666-
RankedTensorType newSrcType =
2667-
RankedTensorType::get(newSrcShape, srcType.getElementType());
2682+
RankedTensorType newSrcType = RankedTensorType::get(
2683+
newSrcShape, srcType.getElementType(), srcType.getEncoding());
26682684
if (srcType == newSrcType ||
26692685
!preservesStaticInformation(srcType, newSrcType) ||
26702686
!tensor::CastOp::areCastCompatible(srcType, newSrcType))
@@ -2815,7 +2831,8 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
28152831
}
28162832
}
28172833

2818-
return RankedTensorType::get(inferredShape, sourceType.getElementType());
2834+
return RankedTensorType::get(inferredShape, sourceType.getElementType(),
2835+
sourceType.getEncoding());
28192836
}
28202837

28212838
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3597,9 +3614,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
35973614
"tiling factors must equal the number of dimensions to tile");
35983615
}
35993616

3600-
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3601-
? packOrUnPack.getDestType()
3602-
: packOrUnPack.getSourceType();
3617+
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
3618+
? packOrUnPack.getDestType()
3619+
: packOrUnPack.getSourceType();
36033620
size_t packedRank = packedType.getRank();
36043621
// Require output rank to match input rank + number of blocking factors.
36053622
if (unpackedRank + mixedTiles.size() != packedRank) {
@@ -3866,7 +3883,8 @@ RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
38663883
ArrayRef<int64_t> outerDimsPerm) {
38673884
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
38683885
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3869-
return RankedTensorType::get(resultShape, sourceType.getElementType());
3886+
return RankedTensorType::get(resultShape, sourceType.getElementType(),
3887+
sourceType.getEncoding());
38703888
}
38713889

38723890
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
473473
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
474474
const BufferizationOptions &options) const {
475475
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
476+
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
476477

477478
// TODO: Implement memory space for this op.
478-
if (options.defaultMemorySpace != Attribute())
479+
if (options.getMemorySpace(tensorType) != Attribute())
479480
return op->emitError("memory space not implemented yet");
480481

481482
// Allocate a buffer for the result.
482483
Location loc = op->getLoc();
483-
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
484484
auto shape = tensorType.getShape();
485485
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
486486
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
588588
const BufferizationOptions &options) const {
589589
auto generateOp = cast<tensor::GenerateOp>(op);
590590

591+
auto type = generateOp.getResult().getType();
592+
591593
// TODO: Implement memory space for this op.
592-
if (options.defaultMemorySpace != Attribute())
594+
if (options.getMemorySpace(type) != Attribute())
593595
return op->emitError("memory space not implemented yet");
594596

595597
// Allocate memory.

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,13 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
122122
// CHECK-LABEL: func.func @linalg_copy(
123123
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
124124
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
125-
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
126-
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
127-
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
128-
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
129-
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
130-
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
131-
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
125+
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
126+
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
127+
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
128+
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
129+
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
130+
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
131+
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
132132
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
133133
// CHECK: }
134134

0 commit comments

Comments
 (0)