Skip to content

[Dialect] [OneDNNGraph] Add ops lowering for llama2 mlp #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/gc/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class OpenMPDialect;
namespace linalg {
class LinalgDialect;
}
namespace linalgx {
class LinalgxDialect;
}

namespace MemRef {
class MemRefDialect;
Expand Down
12 changes: 7 additions & 5 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
Lowers the `onednn_graph` ops to `linalg` ops.
}];
let dependentDialects = [
"func::FuncDialect",
"math::MathDialect",
"arith::ArithDialect",
"tensor::TensorDialect",
"linalg::LinalgDialect"
"func::FuncDialect",
"math::MathDialect",
"arith::ArithDialect",
"tensor::TensorDialect",
"linalg::LinalgDialect",
"linalgx::LinalgxDialect"
];
}

Expand All @@ -37,6 +38,7 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
"tensor::TensorDialect",
"memref::MemRefDialect",
"linalg::LinalgDialect",
"linalgx::LinalgxDialect",
"LLVM::LLVMDialect",
"scf::SCFDialect",
"bufferization::BufferizationDialect",
Expand Down
141 changes: 80 additions & 61 deletions lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,44 @@ SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
return ret;
}

SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
ArrayRef<int64_t> axes, bool keep_dims) {
SmallVector<int64_t> outputShape;
SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have a short note on what is considered canonical for future reference

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment

bool canonicalized = false) {
// get canonicalized reduce axes
auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
: canonicalizeReduceAxes(axes, rank);
auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
// get kept axes
SmallVector<int64_t> keepAxes;
for (int64_t dim = 0, idx = 0; dim < rank; dim++) {
if (idx < (int64_t)reduceAxes.size() && reduceAxes[idx] == dim) {
idx++;
continue;
}
keepAxes.push_back(dim);
}
return keepAxes;
}

SmallVector<int64_t> inferReducedShape(ShapedType operandShape,
ArrayRef<int64_t> axes, bool keepDims,
bool canonicalized = false) {
// get canonicalized reduce axes
auto rank = operandShape.getRank();
auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
: canonicalizeReduceAxes(axes, rank);
auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
// get reduce axis one by one
size_t index = 0;
auto getNextReduceAxis = [&]() {
return (index >= axes.size()) ? -1 : axes[index++];
return (index >= reduceAxes.size()) ? -1 : reduceAxes[index++];
};
// get reduced shape
auto rank = operandShape.getRank();
SmallVector<int64_t> outputShape;
auto axis = getNextReduceAxis();
for (int64_t idx = 0; idx < rank; idx++) {
if (idx == axis) {
axis = getNextReduceAxis();
if (keep_dims) {
if (keepDims) {
outputShape.push_back(1);
}
} else {
Expand All @@ -84,16 +107,16 @@ SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
}

static LogicalResult InferReduceReturnTypes(
ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
bool keep_dims,
ShapedType operandTy, ArrayRef<int64_t> axes, bool keepDims,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// no reduce axes
if (axes.empty()) {
inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
inferredReturnShapes.push_back(ShapedTypeComponents(operandTy));
return success();
}
inferredReturnShapes.push_back(ShapedTypeComponents(
getReducedShape(operandShape, axes, keep_dims), elemType));
inferredReturnShapes.push_back(
ShapedTypeComponents(inferReducedShape(operandTy, axes, keepDims),
operandTy.getElementType()));
return success();
}

Expand All @@ -119,10 +142,10 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
return failure();
}
// canonicalize the reduce axes
auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
auto new_op = rewriter.create<ReduceOp>(
op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
rewriter.replaceOp(op, new_op);
auto newAxes = canonicalizeReduceAxes(op.getAxes(), rank);
auto newOp = rewriter.create<ReduceOp>(
op.getLoc(), op.getType(), op.getOperand(), newAxes, op.getKeepDims());
rewriter.replaceOp(op, newOp);
// NOLINTEND
return success();
}
Expand All @@ -142,12 +165,9 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
llvm::SmallVector<int64_t> outShape; \
auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType()); \
auto rank = operandTy.getRank(); \
ShapeAdaptor inputShape(operandTy); \
return InferReduceReturnTypes( \
inputShape, operandTy.getElementType(), \
canonicalizeReduceAxes(adaptor.getAxes(), rank), \
adaptor.getKeepDims(), inferredReturnShapes); \
return InferReduceReturnTypes(operandTy, adaptor.getAxes(), \
adaptor.getKeepDims(), \
inferredReturnShapes); \
}

#define REDUCE_OP_VERIFY(OP) \
Expand Down Expand Up @@ -181,22 +201,36 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// get batch dims from shape
auto extractBatch = [](const ShapeAdaptor &lhsShape,
const ShapeAdaptor &rhsShape, int64_t range,
int64_t diff, SmallVector<int64_t> &outDims) {
for (int64_t i = 0; i < range; i++) {
// TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
int64_t idx = i - diff;
if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
return failure();
}
outDims.push_back(lhsShape.getDimSize(i));
// get batch dims from 1 multi-batch mat shape
auto extractBatch = [](ShapedType shape, SmallVector<int64_t> &outDims) {
// assuming last 2 input dims are row and col
assert(shape.getRank() >= 2);
for (int64_t i = 0; i < shape.getRank() - 2; i++) {
outDims.push_back(shape.getDimSize(i));
}
return success();
};
// get broadcasted batch dims from 2 multi-batch mat shape,
auto extractBroadcastBatch = [](ShapedType lhsType, ShapedType rhsType,
SmallVector<int64_t> &outDims) {
SmallVector<int64_t> lhsShape(lhsType.getShape());
SmallVector<int64_t> rhsShape(rhsType.getShape());
// assuming last 2 input dims are row and col
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What guarantees it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In else if (lRank > 1 && rRank > 1), it checks for both input rank >= 2, meaning 2 inputs are all matrix and may have batch dims.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, what would happen with a transposed matrix that has batch dimension at the last position for whatever reason? I guess what you are saying is that we will treat it as if the last two are not batch dimensions and if they are the shape/layout was just wrong. Correct?

Copy link
Contributor Author

@LongshengDu LongshengDu Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Transpose attr only Controls whether to transpose the last two dimensions, so batch dims always before last 2 dims according to the onednn spec. If last two dimensions somehow contain a batch dim, it is definitely wrong.

// 0xFF is just a random number > 1, replacing the row and col dims
// so that getBroadcastedShape can match, will be removed after
lhsShape[lhsShape.size() - 1] = 0xFF;
lhsShape[lhsShape.size() - 2] = 0xFF;
rhsShape[rhsShape.size() - 1] = 0xFF;
rhsShape[rhsShape.size() - 2] = 0xFF;
bool ret = OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, outDims);
// remove 0xFF
assert(outDims.size() >= 2);
outDims.pop_back();
outDims.pop_back();
return LogicalResult::success(ret);
};
// get row col of 2d matrix according to transpose info
auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) {
auto getMatRowCol = [](ShapedType shape, bool transpose) {
using pairRowCol = std::pair<int64_t, int64_t>;
auto rank = shape.getRank();
assert(rank > 1);
Expand All @@ -205,8 +239,8 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
: pairRowCol(shape.getDimSize(rank - 2),
shape.getDimSize(rank - 1));
};
ShapeAdaptor lhsShape(adaptor.getInputA().getType());
ShapeAdaptor rhsShape(adaptor.getInputB().getType());
auto lhsShape = cast<ShapedType>(adaptor.getInputA().getType());
auto rhsShape = cast<ShapedType>(adaptor.getInputB().getType());
bool transposeA = adaptor.getTransposeA();
bool transposeB = adaptor.getTransposeB();
int64_t lRank = lhsShape.getRank();
Expand All @@ -223,36 +257,24 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
} else if (lRank == 1 && rRank > 1) {
// 1D x ND
auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape);
status = extractBatch(rhsShape, outShape);
if (lhsShape.getDimSize(0) != rMatRowCol.first) {
return failure();
}
outShape.push_back(rhsShape.getDimSize(rMatRowCol.second));
} else if (lRank > 1 && rRank == 1) {
// ND x 1D
auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape);
status = extractBatch(lhsShape, outShape);
if (lMatRowCol.second != rhsShape.getDimSize(0)) {
return failure();
}
outShape.push_back(lhsShape.getDimSize(lMatRowCol.first));
} else if (lRank > 1 && rRank > 1) {
if (lRank == rRank) {
// ND x ND
auto range = lRank - 2;
status = extractBatch(lhsShape, rhsShape, range, 0, outShape);
} else if (lRank > rRank) {
// MD x ND (M > N)
auto range = lRank - 2;
auto diff = lRank - rRank;
status = extractBatch(lhsShape, rhsShape, range, diff, outShape);
} else {
// ND x MD (M > N)
auto range = rRank - 2;
auto diff = rRank - lRank;
status = extractBatch(rhsShape, lhsShape, range, diff, outShape);
}
//
// ND x ND
// MD x ND (M > N)
// ND x MD (M > N)
status = extractBroadcastBatch(lhsShape, rhsShape, outShape);
auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) {
Expand All @@ -269,16 +291,13 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
inferredReturnShapes.push_back(retShape);
// check for bias broadcasting
if (adaptor.getBias()) {
auto biasType = adaptor.getBias().getType();
ShapeAdaptor biasShape(biasType);

bool biasRankMatch = biasShape.getRank() == 1 ||
biasShape.getRank() == (int64_t)outShape.size();
auto biasType = dyn_cast<ShapedType>(adaptor.getBias().getType());
bool biasRankMatch = biasType.getRank() == 1 ||
biasType.getRank() == (int64_t)outShape.size();
SmallVector<int64_t> resultShape;
if (!biasRankMatch ||
!OpTrait::util::getBroadcastedShape(
retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
resultShape)) {
!OpTrait::util::getBroadcastedShape(retShape.getDims(),
biasType.getShape(), resultShape)) {
return failure();
}
}
Expand Down
Loading