[Transform] Add basic onednn_graph dialect lowering #61
Conversation
Force-pushed from 4e629e5 to cd43162
Force-pushed from 5ee1278 to c06f84d
Just a couple minor comments, otherwise looks fine. I'd probably also add some failure test cases (e.g., for unexpected broadcasting) as a nice-to-have thing; not necessary in this PR though.
int64_t rank = resultTy.getRank();
SmallVector<int64_t> permutation(rank);
std::iota(std::begin(permutation), std::end(permutation), 0);
permutation[rank - 2] = rank - 1;
permutation[rank - 1] = rank - 2;
Since this already assumes it's 2d, can it be static?
Yeah, for 2D matmul it should be static.
}
}

typedef Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);
The definition is a bit confusing since there's also mlir::Value (even though the result is still the same). If there's any better option, go for it. Otherwise it's fine.
This is a typedef for a function that returns mlir::Value.
struct OriginalOperand {
  template <unsigned I>
  static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
    if (I >= op->getNumOperands()) {
      op->emitError("Index exceeds operand num.\n");
      return nullptr;
    }
    return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
  }
};
struct ConstantOperand {
  template <int64_t I>
  static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
    const auto loc = op->getLoc();
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, int64_t(I)));
    } else if (llvm::isa<FloatType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, float(I)));
    } else {
      op->emitError("Not a supported element type for constant.\n");
      return nullptr;
    }
  }
};
These methods are stateless. Why are they wrapped in a struct? Is it just for logical grouping, expecting more methods?
Just for logical grouping; will add more, e.g. getting a constant operand from attrs like onednn_graph.pow.
Looks like a reasonable starting point now
[Transform] Add basic onednn_graph dialect lowering (#61)
Added basic onednn_graph dialect lowering for matmul, relu, add.
Tracking: #28
Depending on: #43
The lowering of elementwise binary ops with a constant operand will use named ops, the same as the other elementwise ops, for consistency, and we will develop it further to explore its strengths and limitations. We can explore more options in the future if named ops cannot comply with our design.