[OneDNN Graph Dialect] Use Broadcast Trait and organize data types #81
Conversation
```cpp
auto ret =
    inferBroadcastShape<ValueShapeRange>(operands, outShape, getShapeIdx);
llvm::SmallVector<int64_t> input1, input2;
getShapeIdx(operands, 0).getDims(input1);
```
Can we use something like `getShape` to get the `ArrayRef` of the shape, for performance?
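For reference, a minimal sketch of that suggestion (the helper name `getShapeOf` is illustrative, not part of this PR), assuming a recent MLIR where `llvm::cast` works on types and the operand has already been verified to be a ranked tensor: `ShapedType::getShape()` returns an `ArrayRef` view owned by the type, so there is no per-call copy into a `SmallVector`.

```cpp
#include "llvm/ADT/ArrayRef.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Illustrative helper: return a view of a ranked operand's static shape.
// getShape() hands back an ArrayRef backed by the type's storage, unlike
// ShapeAdaptor::getDims, which copies every dimension into a SmallVector.
static llvm::ArrayRef<int64_t> getShapeOf(mlir::Value v) {
  return llvm::cast<mlir::ShapedType>(v.getType()).getShape();
}
```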
Thanks~ Optimized the code~
Also, `dyn_cast<ShapedType>()` might be better, in case we need to support other shaped types in the future.
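A sketch of that defensive variant (again, the helper name is illustrative): `dyn_cast<ShapedType>()` accepts any shaped type implementation (tensor, memref, ...) and fails gracefully instead of asserting when the value is unshaped or unranked.

```cpp
#include <optional>

#include "llvm/ADT/ArrayRef.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Illustrative defensive variant: returns std::nullopt rather than asserting
// when the value is not a ranked ShapedType (tensor, memref, ...).
static std::optional<llvm::ArrayRef<int64_t>> tryGetShapeOf(mlir::Value v) {
  auto shaped = llvm::dyn_cast<mlir::ShapedType>(v.getType());
  if (shaped && shaped.hasRank())
    return shaped.getShape();
  return std::nullopt;
}
```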
```cpp
// final shape
auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType());
inferredReturnShapes.push_back(retShape);
// check for bias broadcasting
if (adaptor.getBias()) {
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  ShapeAdaptor matShape(retShape);
  llvm::SmallVector<int64_t> matmulShape;
  matShape.getDims(matmulShape);
```
Can we use something like `getShape` to get the `ArrayRef` of the shape, for performance?
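For illustration, a hedged sketch of the same idea applied to the bias check: `ShapedTypeComponents::getDims()` and `ShapedType::getShape()` both return `ArrayRef` views, so the intermediate `ShapeAdaptor` and the `SmallVector` copy can be dropped. The broadcasting rule shown (right-aligned, NumPy-style, with dynamic dims treated as compatible) is an assumption; the PR's exact bias rule may differ.

```cpp
#include "llvm/ADT/ArrayRef.h"

#include "mlir/IR/BuiltinTypes.h"

// Illustrative right-aligned broadcast-compatibility check over shape views.
// Only static mismatches fail; dynamic dimensions are treated as compatible.
static bool isBroadcastCompatible(llvm::ArrayRef<int64_t> out,
                                  llvm::ArrayRef<int64_t> bias) {
  if (bias.size() > out.size())
    return false;
  for (size_t i = 0; i < bias.size(); ++i) {
    int64_t o = out[out.size() - 1 - i];
    int64_t b = bias[bias.size() - 1 - i];
    if (mlir::ShapedType::isDynamic(o) || mlir::ShapedType::isDynamic(b))
      continue;
    if (b != 1 && b != o)
      return false;
  }
  return true;
}
```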
```tablegen
// Floating-point types.
//===----------------------------------------------------------------------===//
def OneDNNGraph_Float : AnyTypeOf<[F32,
                                   F16,
                                   BF16]>;

//===----------------------------------------------------------------------===//
// Integer types.
//===----------------------------------------------------------------------===//

def OneDNNGraph_Int : AnyTypeOf<[SI<8>,
                                 UI<8>]>;
```
What is the rationale for having separate types? Also, why do we only have fp16 and int8?
Easier to construct Quantize/Dequantize ops later. https://oneapi-src.github.io/oneDNN/dev_guide_op_quantize.html#supported-data-types
For ops like matmul, only f32/bf16/f16 are supported: https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html#supported-data-types
Summary: having separate types just follows the spec and helps us automatically check typing. If an op's semantics change later (e.g., an int->int quantize), the corresponding update should be minor. This will still need to change when fp8eX types are added.
Some small changes: