Skip to content

added not-tested LinalgTilingToParallelLoops pass #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();

std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();

std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();

std::unique_ptr<OperationPass<FuncOp>>
Expand All @@ -33,8 +34,15 @@ std::unique_ptr<OperationPass<FuncOp>>
createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes = {},
ArrayRef<StringRef> distributionTypes = {});

std::unique_ptr<OperationPass<FuncOp>>
createMemoryFootPrintReducePass(int64_t maxFootprint = 0);

std::unique_ptr<OperationPass<FuncOp>>
createLinalgMemoryFootprintReductionPass(int64_t maxFootprint = 0);

std::unique_ptr<OperationPass<FuncOp>>
createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);

std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();

std::unique_ptr<OperationPass<FuncOp>> createLinalgInlineScalarOperandsPass();
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@ def LinalgTiling : FunctionPass<"linalg-tile"> {
];
}

// Declarative registration for the memory-footprint-reduction pass
// (see LinalgMemoryFootprintReductionPass in Tiling.cpp).
def LinalgMemoryFootprintReduction
    : FunctionPass<"linalg-memory-footprint-reduce"> {
  let summary = "Brings the LinalgOps memory footprint below maxMemFootprint";
  let constructor = "mlir::createLinalgMemoryFootprintReductionPass()";
  let options = [
    // The budget is a single scalar. `ListOption` would generate a list-typed
    // pass member, which the pass constructor's scalar assignment
    // (`maxMemFootprint = maxFootprint;`) cannot target; `Option` generates
    // the plain int64_t member it needs. Default 0 matches the C++ factory's
    // default argument.
    Option<"maxMemFootprint", "linalg-max-memory-footprint", "int64_t",
           /*default=*/"0", "Max memory footprint in bytes">
  ];
  let dependentDialects = [
    "AffineDialect",
    "linalg::LinalgDialect",
    "memref::MemRefDialect",
    "scf::SCFDialect"
  ];
}

def LinalgTilingToParallelLoops
: FunctionPass<"linalg-tile-to-parallel-loops"> {
let summary = "Tile operations in the linalg dialect to parallel loops";
Expand Down
183 changes: 183 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Pass/PassManager.h"

#include "llvm/Support/CommandLine.h"

Expand Down Expand Up @@ -510,6 +511,15 @@ static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

// static void applyExtractSliceOfPadTensorSwapPattern(LinalgOp op) {
// MLIRContext *ctx = op.getContext();
// RewritePatternSet patterns(ctx);
// patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
// (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// (void)applyPatternsAndFoldGreedily(
// op, getLinalgTilingCanonicalizationPatterns(ctx));
// }

static void
applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
ArrayRef<int64_t> tileSizes,
Expand All @@ -535,6 +545,127 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
applyExtractSliceOfPadTensorSwapPattern(funcOp);
}

// static void
// applyTilingToLoopPatterns(LinalgTilingLoopType loopType, LinalgOp op,
// ArrayRef<int64_t> tileSizes,
// ArrayRef<StringRef> distributionTypes = {}) {
// auto options = LinalgTilingOptions()
// .setTileSizes(tileSizes)
// .setLoopType(loopType)
// .setDistributionTypes(distributionTypes);
// MLIRContext *ctx = op.getContext();
// RewritePatternSet patterns(ctx);
// insertTilingPatterns(patterns, options);
// scf::populateSCFForLoopCanonicalizationPatterns(patterns);
// (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// (void)applyPatternsAndFoldGreedily(
// op, getLinalgTilingCanonicalizationPatterns(ctx));
// // Drop the marker.
// op->removeAttr(LinalgTransforms::kLinalgTransformMarker);

// // Apply swap pattern after generating loop nest and running
// // canonicalizations.
// applyExtractSliceOfPadTensorSwapPattern(op);
// }

/// Calculates the size (in bytes) of a ranked tensor
/// Returns the size in bytes of a ranked shape whose elements are
/// `elementBitWidth` bits wide.
///
/// The total bit count is rounded up to whole bytes, so sub-byte element
/// types (e.g. i1) are accounted for instead of tripping an assert; for
/// bit widths divisible by 8 the result is identical to the previous
/// `product(shape) * (elementBitWidth / 8)` computation.
///
/// NOTE(review): dynamic dimensions (encoded as -1) would poison the
/// product — confirm callers only pass statically-shaped operands.
static inline size_t getSizeFromShape(llvm::ArrayRef<int64_t> shape,
                                      size_t elementBitWidth) {
  size_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  // Round total bits up to whole bytes.
  return (numElements * elementBitWidth + 7) / 8;
}

/// Ceiling division for positive operands: smallest k with k * denominator
/// >= numerator. NOTE: for numerator == 0 this intentionally-unchanged
/// formula yields 1, not 0 (the -1/denominator term truncates to 0).
static constexpr inline int64_t ceilDiv(int64_t numerator,
                                        int64_t denominator) {
  return (numerator - 1) / denominator + 1;
}

/// Shape and element-width information collected from an op's ranked
/// operands (filled in by getOperands below).
struct RankedOperands {
  struct Operand {
    mlir::ArrayRef<int64_t> shape; // Static dimension sizes of the operand.
    size_t bitWidth;               // Element width in bits.
  };
  llvm::SmallVector<Operand> ops;
  // Set by getOperands: false as soon as any operand is a ranked tensor
  // (or when collection bails out); stays true for pure-memref operand lists.
  bool isParallelizable;
};

/// Collects shape/bit-width info for every operand of `op`.
///
/// Bails out with an empty, non-parallelizable result as soon as an operand
/// is neither a ranked tensor nor a ranked memref ("nothing to tile").
/// A ranked-tensor operand marks the result as non-parallelizable; a list of
/// only ranked memrefs stays parallelizable.
template <typename Op> // TODO: add type trait to make sure Op has `getOperands()'
RankedOperands getOperands(Op &op) {
  RankedOperands result{{}, /*isParallelizable=*/true};
  for (auto operand : op.getOperands()) {
    auto type = operand.getType();
    if (auto tensorTy = type.template dyn_cast<mlir::RankedTensorType>()) {
      result.ops.push_back(
          {tensorTy.getShape(), tensorTy.getElementTypeBitWidth()});
      result.isParallelizable = false;
    } else if (auto memrefTy = type.template dyn_cast<mlir::MemRefType>()) {
      result.ops.push_back(
          {memrefTy.getShape(), memrefTy.getElementTypeBitWidth()});
    } else {
      return {{}, false}; // Nothing to tile.
    }
  }
  return result;
}


/// Returns the smallest power of two strictly greater than `num`
/// (e.g. 7 -> 8, 8 -> 16, 0 -> 1). The strict inequality is what
/// findPreviousPowerOfTwo relies on.
static inline constexpr int64_t findNextPowerOfTwo(int64_t num) {
  int64_t candidate = 1;
  while (candidate <= num)
    candidate *= 2;
  return candidate;
}

/// Returns the largest power of two less than or equal to `num`,
/// or 0 when `num` < 1 (e.g. 9 -> 8, 8 -> 8, 0 -> 0).
static inline constexpr int64_t findPreviousPowerOfTwo(int64_t num) {
  int64_t result = 0;
  for (int64_t bit = 1; bit <= num; bit <<= 1)
    result = bit;
  return result;
}

/// Shrinks `oldShape` so the described tensor (with `bitWidth`-bit elements)
/// fits within `maxTensorSize` bytes: walks dimensions from the outermost,
/// collapsing each too-small dimension to 1 until it finds one large enough
/// to absorb the remaining reduction factor, which it divides and stops.
///
/// NOTE(review): if `maxTensorSize` is 0 (e.g. a zero budget from
/// findPreviousPowerOfTwo), ceilDiv divides by zero — confirm callers
/// guarantee a positive budget.
static llvm::SmallVector<int64_t> findNewShape(llvm::ArrayRef<int64_t> oldShape,
                                               size_t bitWidth,
                                               size_t maxTensorSize) {

  llvm::SmallVector<int64_t> newShape(oldShape.begin(), oldShape.end());
  for (size_t i = 0, end = oldShape.size(); i < end; i++) {
    // Recompute the footprint each iteration: earlier dimensions may already
    // have been collapsed to 1 below.
    auto curSize = getSizeFromShape(newShape, bitWidth);
    // Remaining factor by which the tensor must still shrink.
    auto tileSize = ceilDiv(curSize, maxTensorSize);
    if (oldShape[i] >= tileSize) {
      // This dimension can absorb the whole remaining factor; done.
      newShape[i] = oldShape[i] / tileSize;
      break;
    } else {
      newShape[i] = 1;
    }
  }
  return newShape;
}

/// Result of getTilingStrategy: the tile sizes to apply to one linalg op and
/// whether it may be tiled to parallel loops.
struct TilingStrategy {
  llvm::SmallVector<int64_t> tilingShape; // Tile sizes, one per dimension.
  bool isParallelizable; // Propagated from RankedOperands::isParallelizable.
};

/// Derives a tiling strategy for `op` that aims to keep each operand's
/// footprint under an equal, power-of-two share of `maxMemoryFootprint`
/// bytes. The tile shape is computed from the first ranked operand only.
///
/// NOTE(review): findPreviousPowerOfTwo yields 0 when the per-tensor budget
/// is below 1, which would later divide by zero inside findNewShape/ceilDiv —
/// confirm maxMemoryFootprint is always >= the operand count.
template <typename Op> // TODO: add type trait to make sure Op is the type we want
static TilingStrategy getTilingStrategy(Op &op, int64_t maxMemoryFootprint) {
  auto operands = getOperands(op);
  // We can't tile if we have unranked stuff (getOperands returns an empty
  // list in that case).
  int64_t numRanked = operands.ops.size();
  if (!numRanked)
    return {{}, false};

  // Equal, power-of-two byte budget per operand.
  auto desiredFootprintPerTensor =
      findPreviousPowerOfTwo(maxMemoryFootprint / numRanked);
  auto &rankedOp = *operands.ops.begin();
  return {findNewShape(rankedOp.shape, rankedOp.bitWidth,
                       desiredFootprintPerTensor),
          operands.isParallelizable};
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
LinalgTilingPass() = default;
Expand Down Expand Up @@ -578,6 +709,48 @@ struct LinalgTilingToTiledLoopsPass
}
};

/// Function pass (registered as linalg-memory-footprint-reduce in Passes.td)
/// meant to tile linalg ops until their operand footprint fits under
/// `maxMemFootprint` bytes.
struct LinalgMemoryFootprintReductionPass
    : public LinalgMemoryFootprintReductionBase<
          LinalgMemoryFootprintReductionPass> {
  LinalgMemoryFootprintReductionPass() = default;
  // Seeds the tablegen-generated `maxMemFootprint` option (bytes).
  LinalgMemoryFootprintReductionPass(int64_t maxFootprint) {
    maxMemFootprint = maxFootprint;
  }

  void runOnFunction() override {
    // TODO: unimplemented stub — the pass currently does nothing.
    // Apply tiling patterns for each linalg op here
  }
};

/// Function pass that queues a tiling pass per linalg.generic op found in the
/// function body, with tile sizes chosen to bound each op's operand memory
/// footprint by `maxFootprint` bytes, then runs the queued pipeline.
class MemoryFootPrintReducePass
    : public mlir::PassWrapper<MemoryFootPrintReducePass,
                               mlir::OperationPass<mlir::FuncOp>> {
  int64_t maxFootprint; // Byte budget shared by each op's operands.

public:
  explicit MemoryFootPrintReducePass(int64_t maxFootprint)
      : maxFootprint(maxFootprint) {}

  void runOnOperation() override {
    auto funcOp = getOperation();
    auto &region = funcOp.getRegion();
    // NOTE(review): region.getOps() only visits ops directly in the region's
    // blocks; generic ops nested inside other regions are skipped — confirm
    // that is intended.
    auto genericOps = llvm::make_filter_range(region.getOps(), [](Operation &op) {
      return isa<mlir::linalg::GenericOp>(op);
    });

    OpPassManager pm("builtin.func");
    for (auto &op : genericOps) {
      auto ts = getTilingStrategy(op, maxFootprint);
      // Parallelizable (memref-only) ops get parallel-loop tiling; ops with
      // tensor operands get plain loop tiling.
      if (ts.isParallelizable) {
        pm.addPass(mlir::createLinalgTilingToParallelLoopsPass(ts.tilingShape));
      } else {
        pm.addPass(mlir::createLinalgTilingPass(ts.tilingShape));
      }
    }
    // NOTE(review): each queued pass runs over the ENTIRE function, not just
    // the op whose strategy produced it — with several generic ops the
    // strategies will stomp on each other; verify this per-op intent.
    (void)runPipeline(pm, funcOp);
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
Expand All @@ -596,3 +769,13 @@ mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes,
return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes,
distributionTypes);
}

/// Factory for the linalg-memory-footprint-reduce pass; `maxFootprint` is the
/// byte budget stored in the pass's `maxMemFootprint` option.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgMemoryFootprintReductionPass(int64_t maxFootprint) {
  return std::make_unique<LinalgMemoryFootprintReductionPass>(maxFootprint);
}

/// Factory for MemoryFootPrintReducePass; `maxFootprint` is the per-op
/// operand byte budget used when queueing tiling passes.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createMemoryFootPrintReducePass(int64_t maxFootprint) {
  return std::make_unique<MemoryFootPrintReducePass>(maxFootprint);
}