[Dialect] [OneDNNGraph] Add ops lowering for llama2 mlp #107
@@ -48,6 +48,8 @@ BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)
 // Reduce ops shape infer
 //===----------------------------------------------------------------------===//

+// canonicalize reduced axes
+// make all axes inside the reduced axes array non-negative and in ascending order
 SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
                                             int64_t rank) {
   SmallVector<int64_t> ret(axes.size());
@@ -59,21 +61,46 @@ SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
   return ret;
 }

-SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
-                                     ArrayRef<int64_t> axes, bool keep_dims) {
-  SmallVector<int64_t> outputShape;
+// canonicalize kept axes
+// make all axes inside the kept axes array non-negative and in ascending order
+SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
+                                          bool canonicalized = false) {
+  // get canonicalized reduce axes
+  auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
+                                        : canonicalizeReduceAxes(axes, rank);
+  auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
+  // get kept axes
+  SmallVector<int64_t> keepAxes;
+  for (int64_t dim = 0, idx = 0; dim < rank; dim++) {
+    if (idx < (int64_t)reduceAxes.size() && reduceAxes[idx] == dim) {
+      idx++;
+      continue;
+    }
+    keepAxes.push_back(dim);
+  }
+  return keepAxes;
+}
+
+SmallVector<int64_t> inferReducedShape(ShapedType operandShape,
+                                       ArrayRef<int64_t> axes, bool keepDims,
+                                       bool canonicalized = false) {
+  // get canonicalized reduce axes
+  auto rank = operandShape.getRank();
+  auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
+                                        : canonicalizeReduceAxes(axes, rank);
+  auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
   // get reduce axis one by one
   size_t index = 0;
   auto getNextReduceAxis = [&]() {
-    return (index >= axes.size()) ? -1 : axes[index++];
+    return (index >= reduceAxes.size()) ? -1 : reduceAxes[index++];
   };
   // get reduced shape
-  auto rank = operandShape.getRank();
+  SmallVector<int64_t> outputShape;
   auto axis = getNextReduceAxis();
   for (int64_t idx = 0; idx < rank; idx++) {
     if (idx == axis) {
       axis = getNextReduceAxis();
-      if (keep_dims) {
+      if (keepDims) {
         outputShape.push_back(1);
       }
     } else {
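For reference, a minimal standalone sketch (plain C++ with std::vector instead of the MLIR types; names are illustrative, not part of the dialect) of what the canonicalization helpers above compute: reduce axes are wrapped to be non-negative and sorted ascending, and the kept axes are the complement of the canonical reduce axes over [0, rank).

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-in for canonicalizeReduceAxes:
// wrap negative axes into [0, rank) and sort ascending.
std::vector<int64_t> canonicalizeReduceAxesSketch(std::vector<int64_t> axes,
                                                  int64_t rank) {
  for (auto &a : axes)
    a = a < 0 ? a + rank : a;
  std::sort(axes.begin(), axes.end());
  return axes;
}

// Illustrative stand-in for canonicalizeKeepAxes:
// the kept axes are every dim in [0, rank) that is not reduced.
std::vector<int64_t> canonicalizeKeepAxesSketch(const std::vector<int64_t> &axes,
                                                int64_t rank) {
  auto reduceAxes = canonicalizeReduceAxesSketch(axes, rank);
  std::vector<int64_t> keepAxes;
  for (int64_t dim = 0, idx = 0; dim < rank; dim++) {
    if (idx < (int64_t)reduceAxes.size() && reduceAxes[idx] == dim) {
      idx++; // this dim is reduced, skip it
      continue;
    }
    keepAxes.push_back(dim);
  }
  return keepAxes;
}

// Example: rank = 4, axes = {-1, 0}
//   canonical reduce axes -> {0, 3}
//   kept axes             -> {1, 2}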
@@ -84,16 +111,16 @@ SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
   }

 static LogicalResult InferReduceReturnTypes(
-    ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
-    bool keep_dims,
+    ShapedType operandTy, ArrayRef<int64_t> axes, bool keepDims,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // no reduce axes
   if (axes.empty()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(operandTy));
     return success();
   }
-  inferredReturnShapes.push_back(ShapedTypeComponents(
-      getReducedShape(operandShape, axes, keep_dims), elemType));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(inferReducedShape(operandTy, axes, keepDims),
+                           operandTy.getElementType()));
   return success();
 }

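Similarly, a hedged standalone sketch of the reduced-shape rule applied by inferReducedShape above: a reduced dimension is either dropped or kept as size 1 depending on keepDims (plain C++, illustrative names, canonical axes assumed).

#include <cstdint>
#include <vector>

// Illustrative stand-in for inferReducedShape on a static shape;
// reduceAxes are assumed canonical (non-negative, ascending).
std::vector<int64_t> inferReducedShapeSketch(const std::vector<int64_t> &shape,
                                             const std::vector<int64_t> &reduceAxes,
                                             bool keepDims) {
  std::vector<int64_t> out;
  size_t idx = 0;
  for (int64_t dim = 0; dim < (int64_t)shape.size(); dim++) {
    if (idx < reduceAxes.size() && reduceAxes[idx] == dim) {
      idx++;
      if (keepDims)
        out.push_back(1); // reduced dim is kept with size 1
      continue;           // otherwise the reduced dim is dropped
    }
    out.push_back(shape[dim]);
  }
  return out;
}

// Example: shape {2, 3, 4, 5}, reduceAxes {0, 3}
//   keepDims = true  -> {1, 3, 4, 1}
//   keepDims = false -> {3, 4}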
@@ -119,10 +146,10 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
       return failure();
     }
     // canonicalize the reduce axes
-    auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
-    auto new_op = rewriter.create<ReduceOp>(
-        op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
-    rewriter.replaceOp(op, new_op);
+    auto newAxes = canonicalizeReduceAxes(op.getAxes(), rank);
+    auto newOp = rewriter.create<ReduceOp>(
+        op.getLoc(), op.getType(), op.getOperand(), newAxes, op.getKeepDims());
+    rewriter.replaceOp(op, newOp);
     // NOLINTEND
     return success();
   }
@@ -142,12 +169,9 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
     llvm::SmallVector<int64_t> outShape; \
     auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType()); \
-    auto rank = operandTy.getRank(); \
-    ShapeAdaptor inputShape(operandTy); \
-    return InferReduceReturnTypes( \
-        inputShape, operandTy.getElementType(), \
-        canonicalizeReduceAxes(adaptor.getAxes(), rank), \
-        adaptor.getKeepDims(), inferredReturnShapes); \
+    return InferReduceReturnTypes(operandTy, adaptor.getAxes(), \
+                                  adaptor.getKeepDims(), \
+                                  inferredReturnShapes); \
   }

 #define REDUCE_OP_VERIFY(OP) \
@@ -181,22 +205,37 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     MatMulOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  // get batch dims from shape
-  auto extractBatch = [](const ShapeAdaptor &lhsShape,
-                         const ShapeAdaptor &rhsShape, int64_t range,
-                         int64_t diff, SmallVector<int64_t> &outDims) {
-    for (int64_t i = 0; i < range; i++) {
-      // TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
-      int64_t idx = i - diff;
-      if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
-        return failure();
-      }
-      outDims.push_back(lhsShape.getDimSize(i));
+  // get batch dims from one multi-batch mat shape
+  auto extractBatch = [](ShapedType shape, SmallVector<int64_t> &outDims) {
+    // assuming last 2 input dims are row and col
+    assert(shape.getRank() >= 2);
+    for (int64_t i = 0; i < shape.getRank() - 2; i++) {
+      outDims.push_back(shape.getDimSize(i));
     }
     return success();
   };
+  // get broadcasted batch dims from two multi-batch mat shapes
+  auto extractBroadcastBatch = [](ShapedType lhsType, ShapedType rhsType,
+                                  SmallVector<int64_t> &outDims) {
+    SmallVector<int64_t> lhsShape(lhsType.getShape());
+    SmallVector<int64_t> rhsShape(rhsType.getShape());
+    assert(lhsShape.size() >= 2 && rhsShape.size() >= 2);
+    // assuming last 2 input dims are row and col
+    // 0xFF is just a random number > 1, replacing the row and col dims
+    // so that getBroadcastedShape can match; they will be removed afterwards
+    lhsShape[lhsShape.size() - 1] = 0xFF;
+    lhsShape[lhsShape.size() - 2] = 0xFF;
+    rhsShape[rhsShape.size() - 1] = 0xFF;
+    rhsShape[rhsShape.size() - 2] = 0xFF;
+    bool ret = OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, outDims);
+    // remove 0xFF
+    assert(outDims.size() >= 2);
+    outDims.pop_back();
+    outDims.pop_back();
+    return LogicalResult::success(ret);
+  };
   // get row col of 2d matrix according to transpose info
-  auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) {
+  auto getMatRowCol = [](ShapedType shape, bool transpose) {
     using pairRowCol = std::pair<int64_t, int64_t>;
     auto rank = shape.getRank();
     assert(rank > 1);

Review thread on "// assuming last 2 input dims are row and col" in extractBroadcastBatch:
- What guarantees it?
- In
- I mean, what would happen with a transposed matrix that has batch dimension at the last position for whatever reason? I guess what you are saying is that we will treat it as if the last two are not batch dimensions and if they are the shape/layout was just wrong. Correct?
- Transpose attr only

A second thread attached to the 0xFF comment was marked as resolved by LongshengDu.
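A self-contained sketch of the 0xFF placeholder trick in extractBroadcastBatch, using a hand-rolled numpy-style broadcast in place of OpTrait::util::getBroadcastedShape (static shapes only; names are illustrative): the trailing row/col dims are overwritten with a dummy value so that only the batch dims constrain the broadcast, and the two dummies are popped off afterwards.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal numpy-style shape broadcast (stands in for getBroadcastedShape).
static bool broadcastShapes(std::vector<int64_t> a, std::vector<int64_t> b,
                            std::vector<int64_t> &out) {
  // right-align the shorter shape with leading 1s
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  out.clear();
  for (size_t i = 0; i < a.size(); i++) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      return false;               // incompatible dims
    out.push_back(std::max(a[i], b[i]));
  }
  return true;
}

// Illustrative stand-in for extractBroadcastBatch: broadcast only the batch dims.
static bool extractBroadcastBatchSketch(std::vector<int64_t> lhs,
                                        std::vector<int64_t> rhs,
                                        std::vector<int64_t> &outDims) {
  assert(lhs.size() >= 2 && rhs.size() >= 2);
  // overwrite row/col with a dummy value > 1 so they always "match"
  lhs[lhs.size() - 1] = lhs[lhs.size() - 2] = 0xFF;
  rhs[rhs.size() - 1] = rhs[rhs.size() - 2] = 0xFF;
  bool ok = broadcastShapes(lhs, rhs, outDims);
  // drop the two dummy trailing dims, leaving only the batch dims
  outDims.pop_back();
  outDims.pop_back();
  return ok;
}

// Example: lhs {2, 1, 8, 16}, rhs {4, 16, 32}
//   with dummies: {2, 1, 0xFF, 0xFF} vs {1, 4, 0xFF, 0xFF}
//   broadcast -> {2, 4, 0xFF, 0xFF}, batch dims -> {2, 4}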
@@ -205,8 +244,8 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
                      : pairRowCol(shape.getDimSize(rank - 2),
                                   shape.getDimSize(rank - 1));
   };
-  ShapeAdaptor lhsShape(adaptor.getInputA().getType());
-  ShapeAdaptor rhsShape(adaptor.getInputB().getType());
+  auto lhsShape = cast<ShapedType>(adaptor.getInputA().getType());
+  auto rhsShape = cast<ShapedType>(adaptor.getInputB().getType());
   bool transposeA = adaptor.getTransposeA();
   bool transposeB = adaptor.getTransposeB();
   int64_t lRank = lhsShape.getRank();
@@ -223,36 +262,24 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
   } else if (lRank == 1 && rRank > 1) {
     // 1D x ND
     auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
-    status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape);
+    status = extractBatch(rhsShape, outShape);
     if (lhsShape.getDimSize(0) != rMatRowCol.first) {
       return failure();
     }
     outShape.push_back(rhsShape.getDimSize(rMatRowCol.second));
   } else if (lRank > 1 && rRank == 1) {
     // ND x 1D
     auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
-    status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape);
+    status = extractBatch(lhsShape, outShape);
     if (lMatRowCol.second != rhsShape.getDimSize(0)) {
       return failure();
     }
     outShape.push_back(lhsShape.getDimSize(lMatRowCol.first));
   } else if (lRank > 1 && rRank > 1) {
-    if (lRank == rRank) {
-      // ND x ND
-      auto range = lRank - 2;
-      status = extractBatch(lhsShape, rhsShape, range, 0, outShape);
-    } else if (lRank > rRank) {
-      // MD x ND (M > N)
-      auto range = lRank - 2;
-      auto diff = lRank - rRank;
-      status = extractBatch(lhsShape, rhsShape, range, diff, outShape);
-    } else {
-      // ND x MD (M > N)
-      auto range = rRank - 2;
-      auto diff = rRank - lRank;
-      status = extractBatch(rhsShape, lhsShape, range, diff, outShape);
-    }
-    //
+    // ND x ND
+    // MD x ND (M > N)
+    // ND x MD (M > N)
+    status = extractBroadcastBatch(lhsShape, rhsShape, outShape);
     auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
     auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
     if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) {
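To summarize the rank handling above, here is a small illustrative helper (not code from the PR) for the ND x ND path: the result shape is the broadcast batch dims followed by lhs rows and rhs cols, and inference fails when the inner dims disagree. The worked shapes below are assumptions chosen for illustration.

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Illustrative only: expected result shape for the ND x ND case, given the
// already-broadcast batch dims and the (rows, cols) of each operand after
// applying the transpose flags.
std::optional<std::vector<int64_t>>
matmulResultShapeSketch(std::vector<int64_t> batchDims,
                        std::pair<int64_t, int64_t> lhsRowCol,
                        std::pair<int64_t, int64_t> rhsRowCol) {
  if (lhsRowCol.second != rhsRowCol.first)
    return std::nullopt; // inner dims must agree
  batchDims.push_back(lhsRowCol.first);  // result rows
  batchDims.push_back(rhsRowCol.second); // result cols
  return batchDims;
}

// Worked examples (assumed shapes, transpose flags false):
//   1D x ND : {16}          x {4, 16, 32} -> {4, 32}
//   ND x 1D : {4, 8, 16}    x {16}        -> {4, 8}
//   ND x ND : {4, 8, 16}    x {4, 16, 32} -> {4, 8, 32}
//   MD x ND : {2, 4, 8, 16} x {4, 16, 32} -> {2, 4, 8, 32}  (batch dims broadcast)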
@@ -269,16 +296,13 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
   inferredReturnShapes.push_back(retShape);
   // check for bias broadcasting
   if (adaptor.getBias()) {
-    auto biasType = adaptor.getBias().getType();
-    ShapeAdaptor biasShape(biasType);
-
-    bool biasRankMatch = biasShape.getRank() == 1 ||
-                         biasShape.getRank() == (int64_t)outShape.size();
+    auto biasType = dyn_cast<ShapedType>(adaptor.getBias().getType());
+    bool biasRankMatch = biasType.getRank() == 1 ||
+                         biasType.getRank() == (int64_t)outShape.size();
     SmallVector<int64_t> resultShape;
     if (!biasRankMatch ||
-        !OpTrait::util::getBroadcastedShape(
-            retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
-            resultShape)) {
+        !OpTrait::util::getBroadcastedShape(retShape.getDims(),
+                                            biasType.getShape(), resultShape)) {
       return failure();
     }
   }
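And a rough sketch of the bias check: the bias must have rank 1 or the output rank, and each of its right-aligned dims must be 1 or match the output dim, mirroring the getBroadcastedShape call above (simplified in that it does not let the bias broadcast the output up; plain C++, illustrative names).

#include <cstdint>
#include <vector>

// Illustrative bias-compatibility check against an inferred output shape.
bool biasBroadcastsSketch(const std::vector<int64_t> &outShape,
                          const std::vector<int64_t> &biasShape) {
  // bias rank must be 1 or equal to the output rank
  if (biasShape.size() != 1 && biasShape.size() != outShape.size())
    return false;
  // right-aligned dims must be 1 or equal to the output dim
  for (size_t i = 0; i < biasShape.size(); i++) {
    int64_t outDim = outShape[outShape.size() - 1 - i];
    int64_t biasDim = biasShape[biasShape.size() - 1 - i];
    if (biasDim != 1 && biasDim != outDim)
      return false;
  }
  return true;
}

// Example: outShape {2, 4, 8, 32}: bias {32} and bias {2, 1, 1, 32} pass,
//          bias {8, 32} is rejected (rank 2 is neither 1 nor 4).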
Remaining review thread:
- Would be nice to have a short note on what is considered canonical for future reference.
- Reply: added comment