[Dialect] [OneDNNGraph] Add ops lowering for llama2 mlp #107


Merged: 39 commits into main from longsheng/llma2_onednn_lower on Jun 7, 2024

Conversation

@LongshengDu (Contributor) commented May 29, 2024

Add llama2 MLP ops lowering, update matmul support for batch broadcast, and add matmul lowering for 3Dx2D flatten. A llama2 MLP graph has been placed in onednn-graph-llama2.mlir for codegen testing.
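
For context, here is a minimal sketch of the shape bookkeeping behind the 3Dx2D flatten lowering. The code and the concrete sizes are illustrative assumptions, not the PR's implementation: a [B, M, K] lhs is collapsed to [B*M, K], multiplied by the 2D [K, N] rhs, and the [B*M, N] result is expanded back to [B, M, N].

```cpp
// Sketch only (not the PR's implementation): shape bookkeeping for lowering a
// [B, M, K] x [K, N] matmul by flattening the 3D lhs to 2D and expanding back.
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  std::array<int64_t, 3> lhs = {8, 32, 4096};  // [B, M, K]; sizes are assumptions
  std::array<int64_t, 2> rhs = {4096, 11008};  // [K, N]

  std::array<int64_t, 2> lhsFlat = {lhs[0] * lhs[1], lhs[2]};  // collapse -> [B*M, K]
  std::array<int64_t, 2> outFlat = {lhsFlat[0], rhs[1]};       // 2D matmul -> [B*M, N]
  std::array<int64_t, 3> out = {lhs[0], lhs[1], rhs[1]};       // expand -> [B, M, N]

  assert(lhs[2] == rhs[0]);                                     // K dims must match
  assert(outFlat[0] == out[0] * out[1] && outFlat[1] == out[2]);
  return 0;
}
```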

Note: if using conda, first `export LD_PRELOAD=path/to/libomp.so`, then run `gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void` to execute the generated binary.

Tracking: #117

@LongshengDu LongshengDu added the WIP work in progress label May 29, 2024
@LongshengDu LongshengDu requested a review from kurapov-peter June 3, 2024 04:22
@LongshengDu LongshengDu removed the WIP work in progress label Jun 3, 2024
@LongshengDu LongshengDu requested a review from ciyongch June 3, 2024 04:22
Longsheng Du added 3 commits June 3, 2024 12:27
@LongshengDu LongshengDu requested a review from xurui1995 June 3, 2024 06:32
Longsheng Du added 2 commits June 4, 2024 16:08
@kurapov-peter (Contributor) left a comment:

One concern I had while reading is the semantics of the shapes folding. I wasn't able to formulate any concrete examples and there might be none, but it still makes sense to give it yet another thought (maybe add complex cases where problems may arise).

```cpp
SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
                                     ArrayRef<int64_t> axes, bool keep_dims) {
  SmallVector<int64_t> outputShape;
```

```cpp
SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
```
Contributor:

It would be nice to have a short note on what is considered canonical, for future reference.

Contributor Author:

Added a comment.

Comment on lines 90 to 99
```cpp
Value newVal = op;
if (collapseShape.size() < opShape.size()) {
  assert(collapseShape.size() + bcastDims.size() == bcastShape.size());
  auto reassociation =
      computeReassociationByAnchor(keepDims, opTy.getRank());
  ShapedType collapseTy =
      RankedTensorType::get(collapseShape, opTy.getElementType());
  newVal = rewriter.create<tensor::CollapseShapeOp>(loc, collapseTy, newVal,
                                                    reassociation);
}
```
Contributor:

Could you please elaborate on what's going on? And what does an anchor mean in this context?

Contributor Author:

Consider the kept dims [a, b, ..., c] as the anchor, so reassociation = [[...a...], [...b...], ..., [...c...]].
E.g. for shape [16, 1, 32, 1, 64], rank = 5, kept dims = [0, 2, 4]:
[16, 1, 32, 1, 64] --collapse-> [16, 32, 64]
reassociation = [[0], [1, 2], [3, 4]]
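
As a standalone illustration of that rule, here is a minimal sketch (not the PR's computeReassociationByAnchor; the function name and the handling of trailing non-kept dims are assumptions): each kept ("anchor") dim absorbs every non-kept dim since the previous anchor.

```cpp
// Sketch of the anchor-based grouping described above; not the PR's code.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<std::vector<int64_t>>
reassociationByAnchor(const std::vector<int64_t> &keepDims, int64_t rank) {
  std::vector<std::vector<int64_t>> groups;
  int64_t next = 0;
  for (int64_t anchor : keepDims) {
    std::vector<int64_t> group;
    for (int64_t d = next; d <= anchor; ++d) // absorb dims since the last anchor
      group.push_back(d);
    groups.push_back(group);
    next = anchor + 1;
  }
  // Assumption: any trailing non-kept dims fold into the last group.
  for (int64_t d = next; d < rank && !groups.empty(); ++d)
    groups.back().push_back(d);
  return groups;
}

int main() {
  // Shape [16, 1, 32, 1, 64], kept dims [0, 2, 4] -> [[0], [1, 2], [3, 4]].
  auto r = reassociationByAnchor({0, 2, 4}, 5);
  assert(r.size() == 3);
  assert((r[1] == std::vector<int64_t>{1, 2}));
  assert((r[2] == std::vector<int64_t>{3, 4}));
  return 0;
}
```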

Contributor Author:

Added a comment.

Contributor:

Thanks. Feels like this can still be better in the following sense. First, we introduce an "anchor", which is the result of the collapse transformation. Kept dims are an array of indices into the shape. Finally, reassociation would be better off with a clear definition. None of these are obvious, and they make you read the code to understand. For example, what would happen to [16, 1, 1, 32], rank 4, kept dims [0, 3]? One would expect a [16, 32]. What should the reassociation look like?
So, for anchor: do we need another term at all? It overloads the term used for fusion as well. Kept dims are clear, I think. The reassociation we should also clarify.

Contributor Author:

For [16, 1, 1, 32] -> [16, 32], rank = 4, kept dims = [0, 3]:
reassociation = [[0], [1, 2, 3]]
anchor_dims is just an internal term/variable name in this file, representing which dims to collapse to / expand from; I don't think it will conflict with the term used for fusion.
reassociation is not introduced by us; it is the attribute of tensor.collapse_shape and tensor.expand_shape from the MLIR tensor dialect.
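
Plugging this case into the earlier sketch gives the same answer (again using the hypothetical reassociationByAnchor helper, not the PR's code):

```cpp
// Shape [16, 1, 1, 32], kept dims [0, 3] -> reassociation [[0], [1, 2, 3]],
// i.e. dim 3 absorbs the unit dims 1 and 2, and the collapse yields [16, 32].
auto groups = reassociationByAnchor({0, 3}, 4);
```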

```cpp
SmallVector<int64_t> lhsShape(lhsType.getShape());
SmallVector<int64_t> rhsShape(rhsType.getShape());
assert(lhsShape.size() >= 2 && rhsShape.size() >= 2);
// assuming last 2 input dims are row and col
```
Contributor:

What guarantees it?

Contributor Author:

In `else if (lRank > 1 && rRank > 1)`, it checks that both inputs have rank >= 2, meaning both inputs are matrices and may have batch dims.

Contributor:

I mean, what would happen with a transposed matrix that has a batch dimension at the last position for whatever reason? I guess what you are saying is that we will treat the last two as if they are not batch dimensions, and if they are, the shape/layout was just wrong. Correct?

Contributor Author (@LongshengDu) commented Jun 7, 2024:

The transpose attr only controls whether to transpose the last two dimensions, so batch dims always come before the last 2 dims according to the oneDNN spec. If the last two dimensions somehow contain a batch dim, it is definitely wrong.
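
A minimal sketch of that rule (an illustration only, not the PR's code): the transpose flag swaps just the trailing two dims and leaves any batch dims in place.

```cpp
// Illustration only: transpose swaps only the last two dims; batch dims stay put.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<int64_t> applyTranspose(std::vector<int64_t> shape, bool transpose) {
  if (transpose && shape.size() >= 2)
    std::swap(shape[shape.size() - 2], shape[shape.size() - 1]);
  return shape;
}

int main() {
  // [batch, rows, cols] = [4, 32, 64] -> [4, 64, 32]; the batch dim is untouched.
  assert((applyTranspose({4, 32, 64}, true) == std::vector<int64_t>{4, 64, 32}));
  return 0;
}
```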

@LongshengDu (Contributor Author) commented Jun 6, 2024:

> One concern I had while reading is the semantics of the shapes folding. I wasn't able to formulate any concrete examples and there might be none, but it still makes sense to give it yet another thought (maybe add complex cases where problems may arise).

@kurapov-peter Can you specify what you consider the semantics of the shapes folding? We can look into it.

@xurui1995 (Contributor) left a comment:

LGTM

```diff
@@ -143,6 +144,8 @@ class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
     auto op = getOperation();
     PassManager pm{op->getContext()};
     populateCPUPipeline(pm);
+    // TODO(longsheng): add a option to
+    // disable threading and enable pm.enableIRPrinting();
```
Contributor:

What was this TODO comment for previously?

Contributor Author:

There was no code here previously. I used to add pm.enableIRPrinting() to debug internal passes, but this function also requires disabling threading; we can figure it out later.
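
For reference, a minimal sketch of the debug setup being described, assuming MLIR's standard PassManager/MLIRContext API; this is not what the pipeline currently does, it is what the TODO would eventually make optional:

```cpp
// Debug-only sketch inside the pipeline's run method: IR printing instruments
// the pass manager and requires single-threaded execution, hence the pairing.
auto op = getOperation();
op->getContext()->disableMultithreading();

PassManager pm{op->getContext()};
pm.enableIRPrinting(); // print IR around each pass with the default callbacks
populateCPUPipeline(pm);
```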

@kurapov-peter kurapov-peter merged commit d7c3c0b into main Jun 7, 2024
4 checks passed
@LongshengDu LongshengDu deleted the longsheng/llma2_onednn_lower branch June 7, 2024 15:55
@LongshengDu LongshengDu linked an issue Jul 3, 2024 that may be closed by this pull request
Successfully merging this pull request may close these issues:
- Add support for Llama 2 MLP OPs on oneDNN Graph dialect