[Dialect] [OneDNNGraph] Add ops lowering for llama2 mlp #107
Conversation
One concern I had while reading is the semantics of the shapes folding. I wasn't able to formulate any concrete examples and there might be none, but it still makes sense to give it yet another thought (maybe add complex cases where problems may arise).
SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
                                     ArrayRef<int64_t> axes, bool keep_dims) {
  SmallVector<int64_t> outputShape;
SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
Would be nice to have a short note on what is considered canonical for future reference
added comment
    Value newVal = op;
    if (collapseShape.size() < opShape.size()) {
      assert(collapseShape.size() + bcastDims.size() == bcastShape.size());
      auto reassociation =
          computeReassociationByAnchor(keepDims, opTy.getRank());
      ShapedType collapseTy =
          RankedTensorType::get(collapseShape, opTy.getElementType());
      newVal = rewriter.create<tensor::CollapseShapeOp>(loc, collapseTy, newVal,
                                                        reassociation);
    }
Could you please elaborate on what's going on? And what does an anchor mean in this context?
Consider the kept dims [a, b, ..., c] as the anchor, so reassociation = [[...a...], [...b...], ..., [...c...]].
E.g. for shape [16, 1, 32, 1, 64], rank = 5, kept dims = [0, 2, 4]:
[16, 1, 32, 1, 64] --collapse-> [16, 32, 64]
reassociation = [[0], [1, 2], [3, 4]]
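For future readers, here is a rough sketch of how such a reassociation could be built from the kept dims. It is only an illustration of the idea, not the actual computeReassociationByAnchor in this PR; the helper name is made up, and it assumes keepDims is sorted, non-empty, and in range.

```cpp
#include "mlir/Dialect/Utils/ReassociationIndices.h" // mlir::ReassociationIndices
#include "mlir/Support/LLVM.h"                       // SmallVector, ArrayRef aliases

using namespace mlir;

// Illustrative sketch (not the PR implementation): each kept dim anchors one
// group; the group absorbs the dropped dims since the previous kept dim, and
// any trailing dropped dims are folded into the last group (an assumption).
static SmallVector<ReassociationIndices>
sketchReassociationByAnchor(ArrayRef<int64_t> keepDims, int64_t rank) {
  SmallVector<ReassociationIndices> reassociation;
  int64_t pos = 0;
  for (int64_t keep : keepDims) {
    ReassociationIndices group;
    for (; pos <= keep; ++pos) // dims (prevKeep, keep] form one group
      group.push_back(pos);
    reassociation.push_back(group);
  }
  for (; pos < rank; ++pos) // trailing dropped dims join the last group
    reassociation.back().push_back(pos);
  return reassociation;
}

// sketchReassociationByAnchor({0, 2, 4}, 5) yields [[0], [1, 2], [3, 4]]
```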
added comment
Thanks. Feels like this can still be better in the following sense. First, we introduce an "anchor", which is the result of the collapse transformation. Kept dims is an array of indices into the shape. Finally, reassociation would be better off with a clear definition. None of these are obvious, which forces you to read the code to understand. For example, what would happen to [16, 1, 1, 32], rank 4, kept dims [0, 3]? One would expect [16, 32]. What should the reassociation look like?
So, for anchor - do we need another term at all? It overloads the term used for fusion as well. Kept dims are clear, I think. The reassociation we should also clarify.
For [16, 1, 1, 32] -> [16, 32], rank = 4, kept dims = [0, 3]:
reassociation = [[0], [1, 2, 3]]
anchor_dims is just an internal term/variable name in this file, representing which dims to collapse to / expand from; I don't think it will conflict with the term used for fusion.
reassociation is not introduced by us; it is the attribute of tensor.collapse_shape and tensor.expand_shape from the MLIR tensor dialect.
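To make that concrete, a hypothetical snippet for the [16, 1, 1, 32] -> [16, 32] case above; `rewriter`, `loc`, and the source value `src` are assumed to exist in the surrounding lowering code and are not part of this PR's quoted diff.

```cpp
// Hypothetical illustration only; `rewriter` (an OpBuilder/PatternRewriter),
// `loc`, and `src` (a value of type tensor<16x1x1x32xf32>) are assumed.
SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2, 3}};
auto collapsedTy = RankedTensorType::get({16, 32}, rewriter.getF32Type());
Value collapsed = rewriter.create<tensor::CollapseShapeOp>(
    loc, collapsedTy, src, reassociation);
// `collapsed` now has type tensor<16x32xf32>.
```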
  SmallVector<int64_t> lhsShape(lhsType.getShape());
  SmallVector<int64_t> rhsShape(rhsType.getShape());
  assert(lhsShape.size() >= 2 && rhsShape.size() >= 2);
  // assuming last 2 input dims are row and col
What guarantees it?
In else if (lRank > 1 && rRank > 1), it checks that both inputs have rank >= 2, meaning both inputs are matrices and may have batch dims.
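For context, a self-contained sketch of the kind of rank-based dispatch being described; the names and structure are illustrative, not the exact lowering code in this PR.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch of the rank dispatch discussed above: in the last
// branch both operands have rank >= 2, so their trailing two dims are
// treated as (row, col) and any leading dims as batch dims, which is what
// the quoted assert relies on.
enum class MatmulKind { VecVec, VecMat, MatVec, MatMat };

MatmulKind classifyMatmul(int64_t lRank, int64_t rRank) {
  assert(lRank >= 1 && rRank >= 1);
  if (lRank == 1 && rRank == 1)
    return MatmulKind::VecVec; // dot product
  if (lRank == 1 && rRank > 1)
    return MatmulKind::VecMat; // vector x (possibly batched) matrix
  if (lRank > 1 && rRank == 1)
    return MatmulKind::MatVec; // (possibly batched) matrix x vector
  return MatmulKind::MatMat;   // matrix x matrix, possibly with batch dims
}
```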
I mean, what would happen with a transposed matrix that has a batch dimension at the last position for whatever reason? I guess what you are saying is that we will treat it as if the last two are not batch dimensions, and if they are, the shape/layout was just wrong. Correct?
The transpose attr only "Controls whether to transpose the last two dimensions", so batch dims always come before the last 2 dims according to the oneDNN spec. If the last two dimensions somehow contain a batch dim, it is definitely wrong.
@kurapov-peter Can you specify what you consider as the semantics of the shapes folding? We can look into it.
LGTM
@@ -143,6 +144,8 @@ class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
   auto op = getOperation();
   PassManager pm{op->getContext()};
   populateCPUPipeline(pm);
+  // TODO(longsheng): add a option to
+  // disable threading and enable pm.enableIRPrinting();
What was this TODO comment for previously?
There was no code here previously. I used to add pm.enableIRPrinting() for debugging internal passes, but this function also requires disabling threading; we can figure it out later.
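For reference, the intended debug setup would look roughly like this; a sketch only, not committed code, with `op` and `pm` being the variables from the quoted pipeline snippet above.

```cpp
// Sketch only: MLIR's IR-printing instrumentation requires a single-threaded
// pass manager, so multithreading must be disabled on the context first.
op->getContext()->disableMultithreading();
pm.enableIRPrinting();
```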
Add llama2 MLP ops lowering, update matmul support for batch broadcast, and add matmul lowering for 3D x 2D flatten. A llama2 MLP graph has been placed in onednn-graph-llama2.mlir for codegen testing.
Note: if using conda, first export LD_PRELOAD=path/to/libomp.so, then run gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void to execute the generated binary.
Tracking: #117