[mlir][Vector] Add narrow type emulation pattern for vector.maskedload #68443
Awesome! Just a comment for now about the passthru value.
Would you also mind adding some description to the commit message? That would help. Thanks!
// (3xi4) from the memref, we need to set the second half of the last i8
// that was effectively loaded (i.e. the second i8) to 0.
Trying to understand if this is really needed. Given that we are only emulating the vector loads/stores and bitcasting them back to the i4 type (see vector load case), it should be ok to just do an arith.select using the original mask, the loaded i4 vector and the original passthru value, right? If that is the case, all the code below would be really simplified!
I'm not sure I see how to do this. Do you mean having a loop of arith.select based on the index?
I'm probably the one missing something :) I meant:
%i8ld = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru ...
%i4ld = vector.bitcast %i8ld : vector<3xi8> to vector<6xi4>
%final_ld = arith.select %mask, %i4ld, %pass_thru : vector<6xi1>, vector<6xi4>
Would this work? The generated asm may not be ideal but it should definitely reduce the complexity of this lowering.
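To make the suggestion above concrete, here is a minimal standalone C++ sketch that models the proposed sequence (byte-wide masked load, bitcast back to i4, then a select on the original mask) and compares it with a direct i4 masked load. It is a simulation only, not the MLIR pattern from this PR. The function names are hypothetical, and it assumes an even element count, two i4 values per byte with the even-indexed element in the low nibble, and a byte-level mask that enables a byte whenever either of its nibbles is enabled.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Each i4 value is held in the low 4 bits of a uint8_t. Packed memory stores
// two i4 values per byte, even-indexed element in the low nibble (assumption).

// Reference semantics: masked load of i4 elements straight from packed memory.
std::vector<uint8_t> maskedLoadI4Reference(const std::vector<uint8_t> &packed,
                                           const std::vector<bool> &mask,
                                           const std::vector<uint8_t> &passThru) {
  std::vector<uint8_t> result(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i) {
    if (mask[i]) {
      const uint8_t byte = packed[i / 2];
      result[i] = static_cast<uint8_t>((i % 2 == 0) ? (byte & 0x0F) : (byte >> 4));
    } else {
      result[i] = static_cast<uint8_t>(passThru[i] & 0x0F);
    }
  }
  return result;
}

// Emulated semantics: i8 masked load, bitcast to i4, select on original mask.
// mask.size() is assumed to be even.
std::vector<uint8_t> maskedLoadI4Emulated(const std::vector<uint8_t> &packed,
                                          const std::vector<bool> &mask,
                                          const std::vector<uint8_t> &passThru) {
  const std::size_t numBytes = mask.size() / 2;

  // Byte-level masked load: a byte is read if any of its nibbles is enabled;
  // otherwise it comes from the pass-through value bitcast to bytes.
  std::vector<uint8_t> loadedBytes(numBytes);
  for (std::size_t b = 0; b < numBytes; ++b) {
    const bool anyActive = mask[2 * b] || mask[2 * b + 1];
    const uint8_t passByte = static_cast<uint8_t>(
        (passThru[2 * b] & 0x0F) | ((passThru[2 * b + 1] & 0x0F) << 4));
    loadedBytes[b] = anyActive ? packed[b] : passByte;
  }

  // Bitcast back to i4, then select per element with the original mask. The
  // select discards nibbles that were loaded only because they share a byte
  // with an enabled element.
  std::vector<uint8_t> result(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i) {
    const uint8_t byte = loadedBytes[i / 2];
    const uint8_t loaded =
        static_cast<uint8_t>((i % 2 == 0) ? (byte & 0x0F) : (byte >> 4));
    result[i] = mask[i] ? loaded : static_cast<uint8_t>(passThru[i] & 0x0F);
  }
  return result;
}
```

With a mask enabling the first three of six elements, the second byte is loaded in full, and the final select is what restores the pass-through value in the fourth element.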
✅ With the latest revision this PR passed the C/C++ code formatter.
As with the vector.load and vector.transfer_read ops, a vector.maskedload op with a narrow bitwidth (e.g., i4) can now be converted to a supported wider bitwidth (e.g., i8 or i32).
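As a rough illustration of the size arithmetic behind this conversion, the following C++ sketch computes the widened vector length and the widened prefix-mask length for a constant mask. It is not code from the PR. The struct and function names are hypothetical, and it assumes the mask enables a leading prefix of elements (as vector.constant_mask [3] does) and that a wide element must be loaded whenever any narrow element inside it is enabled, hence the rounding up.

```cpp
#include <cstddef>

// Hypothetical helper type: sizes of the emulated (widened) masked load.
struct WidenedMaskedLoad {
  std::size_t numWideElements;       // e.g. 6 x i4 becomes 3 x i8
  std::size_t numActiveWideElements; // prefix length of the widened mask
};

// Computes the widened sizes for a prefix mask. wideBits is assumed to be a
// multiple of narrowBits.
WidenedMaskedLoad widenConstantMask(std::size_t numNarrowElements,
                                    std::size_t numActiveNarrow,
                                    unsigned narrowBits, unsigned wideBits) {
  const std::size_t scale = wideBits / narrowBits;
  // Round up so that a partially enabled wide element is still loaded.
  auto ceilDiv = [](std::size_t a, std::size_t b) { return (a + b - 1) / b; };
  return {ceilDiv(numNarrowElements, scale), ceilDiv(numActiveNarrow, scale)};
}
```

For a vector<6xi4> load with the first three elements enabled and i8 as the wide type, this gives three wide elements of which the first two are enabled, matching the "second i8" mentioned in the code comment earlier in the thread.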
… mask is an extraction
LG! If you can remove the commented code that should be it! Thanks a lot for looking into this!
…vector.maskedload arith.select should be used instead of a series of manual mask manipulating ops (arith.and/or/extsi)
… emulation pattern
Thanks for the review!
func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
  %c0 = arith.constant 0: index
  %mask = vector.constant_mask [3] : vector<6xi1>
  %1 = vector.maskedload %A[%c0], %mask, %passthru :
    memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
  return %1 : vector<6xi4>
}
I just noticed that this is not tested because the patterns are not populated.
llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Lines 162 to 165 in a0bd636
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  populateVectorNarrowTypeRewritePatterns(patterns);
}
To test it, we need to set up a type converter and populate those patterns.
No description provided.