Skip to content

[Dialect] [OneDNNGraph] Add onednn_graph ops for llama2 mlp #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 135 additions & 13 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,38 @@ include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td"
// Base class for all operations in the OneDNNGraph dialect.
class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
    Op<OneDNNGraphDialect, mnemonic, traits>;

// Base class for element-wise unary ops (relu, sigmoid, ...).
// SameOperandsAndResultType: the result type is exactly the operand type,
// so no separate shape inference hook is required.
class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
    OneDNNGraph_Op<mnemonic, traits #
        [SameOperandsAndResultType]> {
  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}

// Base class for element-wise binary ops with multi-directional broadcast.
// The result shape is inferred (InferTensorTypeAdaptor) as the broadcast of
// the two operand shapes; element types of operands and result must match.
class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
    OneDNNGraph_Op<mnemonic, traits #
        [SameOperandsAndResultElementType, InferTensorTypeAdaptor,
         ResultsBroadcastableShape]> {
  // auto_broadcast defaults to true; when false, shape inference requires
  // both input types to be identical (see BINARY_OP_SHAPE_INFER).
  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
                       OneDNNGraph_FloatTensor:$input_b,
                       DefaultValuedOptionalAttr<BoolAttr, "true">:$auto_broadcast);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}

// Base class for reduction ops (reduce_sum, reduce_mean, ...).
// `axes` lists the dimensions to reduce; negative values count from the
// back. `keep_dims` retains reduced dimensions with size 1 in the result.
class OneDNNGraph_ReduceOp<string mnemonic, list<Trait> traits = []> :
    OneDNNGraph_Op<mnemonic, traits #
        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$axes,
                       DefaultValuedOptionalAttr<BoolAttr, "false">:$keep_dims);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  // Verifier rejects out-of-range axes; canonicalizer rewrites `axes` into
  // sorted, unique, non-negative form (see OneDNNGraphOps.cpp).
  let hasVerifier = 1;
  let hasCanonicalizer = 1;
  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}
Expand All @@ -48,36 +64,142 @@ class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
// OneDNNGraph op definitions
//===----------------------------------------------------------------------===//

// Matmul

def OneDNNGraph_MatMulOp :
    OneDNNGraph_Op<"matmul",
        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
  let summary = [{
    MatMul operation computes the product of two tensors with optional bias addition.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
  }];

  // transpose_a / transpose_b transpose the last two dimensions of the
  // corresponding input before multiplication; `bias` is optional and is
  // broadcast-added to the matmul result.
  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
                       OneDNNGraph_FloatTensor:$input_b,
                       Optional<OneDNNGraph_FloatTensor>:$bias,
                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}

// Common Unary

// Element-wise rectified linear unit; semantics per the oneDNN spec below.
def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
  let summary = "element-wise relu";
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
  }];
}

// Element-wise logistic sigmoid; semantics per the oneDNN spec below.
def OneDNNGraph_SigmoidOp : OneDNNGraph_ElemwiseUnaryOp<"sigmoid"> {
  let summary = "element-wise sigmoid";
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_sigmoid.html`
  }];
}

// Special Unary

// Not derived from ElemwiseUnaryOp: the result element type differs from
// the operand's, so only the shape is constrained to match.
def OneDNNGraph_TypeCastOp : OneDNNGraph_Op<"type_cast", [SameOperandsAndResultShape]> {
  let summary = [{
    TypeCast operation performs element-wise cast from input data type
    to the data type given by output tensor.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_typecast.html`
  }];

  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}

// Power with a constant exponent: the exponent is the f32 attribute `beta`,
// not a second tensor operand, so this stays a unary-shaped op.
def OneDNNGraph_PowOp : OneDNNGraph_Op<"pow", [SameOperandsAndResultType]> {
  let summary = [{
    Pow operation performs an element-wise power operation on a given input
    tensor with a single value attribute beta as its exponent.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_pow.html`
  }];

  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
                       F32Attr:$beta);
  let results = (outs OneDNNGraph_FloatTensor:$result);

  let assemblyFormat =
    "operands attr-dict `:` functional-type(operands, results)";
}

// Common Binary

// Element-wise addition; Commutative enables operand-order canonicalization.
def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
  let summary = [{
    Add operation performs element-wise addition operation with two
    given tensors applying multi-directional broadcast rules.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
  }];
}

// Element-wise multiply; Commutative enables operand-order canonicalization.
def OneDNNGraph_MulOp : OneDNNGraph_ElemwiseBinaryOp<"mul", [Commutative]> {
  let summary = [{
    Multiply operation performs element-wise multiply operation with two
    given tensors applying multi-directional broadcast rules.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_multiply.html`
  }];
}

// Element-wise subtraction (not commutative, hence no Commutative trait).
def OneDNNGraph_SubOp : OneDNNGraph_ElemwiseBinaryOp<"sub"> {
  let summary = [{
    Subtract operation performs element-wise subtraction operation with
    two given tensors applying multi-directional broadcast rules.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_subtract.html`
  }];
}

// Element-wise division (not commutative, hence no Commutative trait).
def OneDNNGraph_DivOp : OneDNNGraph_ElemwiseBinaryOp<"div"> {
  let summary = [{
    Divide operation performs element-wise division operation with two
    given tensors applying multi-directional broadcast rules.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_divide.html`
  }];
}

// Common Reduce

// Sum reduction over the dimensions listed in `axes` (see ReduceOp base).
def OneDNNGraph_ReduceSumOp : OneDNNGraph_ReduceOp<"reduce_sum"> {
  let summary = [{
    ReduceSum operation performs the reduction with addition on a given
    src data along dimensions specified by axes.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducesum.html`
  }];
}

// Arithmetic-mean reduction over the dimensions listed in `axes`.
def OneDNNGraph_ReduceMeanOp : OneDNNGraph_ReduceOp<"reduce_mean"> {
  let summary = [{
    ReduceMean operation performs the reduction with finding the arithmetic
    mean on a given src data along dimensions specified by axes.
  }];
  let description = [{
    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducemean.html`
  }];
}

#endif // ONEDNNGRAPH_OPS
173 changes: 158 additions & 15 deletions lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,166 @@
namespace mlir {
namespace onednn_graph {

LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outShape;
auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
auto getShapeIdx = [&operands](size_t i) {
return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
//===----------------------------------------------------------------------===//
// Binary ops shape infer
//===----------------------------------------------------------------------===//

// Expands to the InferTensorTypeAdaptor hook for an element-wise binary op:
//  * if auto_broadcast is false, the two input types must be identical,
//    otherwise inference fails;
//  * otherwise the result shape is the multi-directional broadcast of the
//    input shapes (OpTrait::util::getBroadcastedShape), with input_a's
//    element type; incompatible shapes make getBroadcastedShape fail.
// NOTE(review): assumes both operands are ranked ShapedTypes — getShape()
// on an unranked operand would assert; confirm the ODS type constraint
// rules that out. Comments cannot live inside the macro body (a `//` would
// swallow the `\` continuation), hence this header.
#define BINARY_OP_SHAPE_INFER(OP)                                              \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    auto inputTy0 = dyn_cast<ShapedType>(adaptor.getInputA().getType());       \
    auto inputTy1 = dyn_cast<ShapedType>(adaptor.getInputB().getType());       \
    if (!adaptor.getAutoBroadcast() && (inputTy0 != inputTy1)) {               \
      return failure();                                                        \
    }                                                                          \
    llvm::SmallVector<int64_t> outShape;                                       \
    auto ret = OpTrait::util::getBroadcastedShape(                             \
        inputTy0.getShape(), inputTy1.getShape(), outShape);                   \
    inferredReturnShapes.push_back(                                            \
        ShapedTypeComponents(outShape, inputTy0.getElementType()));            \
    return LogicalResult::success(ret);                                        \
  }

// Instantiate shape inference for all element-wise binary ops.
BINARY_OP_SHAPE_INFER(onednn_graph::AddOp)
BINARY_OP_SHAPE_INFER(onednn_graph::MulOp)
BINARY_OP_SHAPE_INFER(onednn_graph::SubOp)
BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)

//===----------------------------------------------------------------------===//
// Reduce ops shape infer
//===----------------------------------------------------------------------===//

// Normalizes reduce axes into canonical form: negative axes are shifted by
// `rank` to count from the front, then the list is sorted and deduplicated.
// Axes below -rank remain negative and are left for the verifier to reject.
SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
                                            int64_t rank) {
  SmallVector<int64_t> normalized;
  normalized.reserve(axes.size());
  for (int64_t axis : axes)
    normalized.push_back(axis < 0 ? axis + rank : axis);
  llvm::sort(normalized);
  normalized.erase(std::unique(normalized.begin(), normalized.end()),
                   normalized.end());
  return normalized;
}

// Computes the shape that remains after reducing `operandShape` along
// `axes`. Dimensions named in `axes` are dropped, or kept as size 1 when
// `keep_dims` is set; all other dimension sizes pass through unchanged.
// Axes are consumed in list order as dimensions are walked, so callers are
// expected to pass canonicalized (sorted, unique) axes.
SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
                                     ArrayRef<int64_t> axes, bool keep_dims) {
  SmallVector<int64_t> reducedDims;
  size_t axisPos = 0;
  const int64_t rank = operandShape.getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    const bool isReducedDim = axisPos < axes.size() && axes[axisPos] == dim;
    if (isReducedDim) {
      // Consume this axis; emit a unit dimension only when keeping dims.
      ++axisPos;
      if (keep_dims)
        reducedDims.push_back(1);
    } else {
      reducedDims.push_back(operandShape.getDimSize(dim));
    }
  }
  return reducedDims;
}

auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
outShape);
inferredReturnShapes.push_back(
ShapedTypeComponents(outShape, resultTy.getElementType()));
return LogicalResult::success(ret);
// Builds the inferred result type for a reduce op. With an empty axis list
// the operand shape passes through unchanged; otherwise the reduced shape
// is derived from the (already canonicalized) axes.
static LogicalResult InferReduceReturnTypes(
    ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
    bool keep_dims,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (axes.empty()) {
    // Nothing to reduce: result mirrors the operand shape.
    inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
    return success();
  }
  SmallVector<int64_t> reducedShape =
      getReducedShape(operandShape, axes, keep_dims);
  inferredReturnShapes.push_back(
      ShapedTypeComponents(reducedShape, elemType));
  return success();
}

// Canonicalization pattern shared by all reduce ops: rewrites an op whose
// `axes` attribute is not strictly ascending and non-negative into an
// equivalent op with normalized axes (via canonicalizeReduceAxes).
template <typename ReduceOp>
struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
  using OpRewritePattern<ReduceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto rank = dyn_cast<ShapedType>(op.getOperand().getType()).getRank();
    // consider canonicalized if all axes are non-negative in ascending order
    // Note: disable tidy here due to dangling reference in OperationState
    // NOLINTBEGIN
    bool canonicalized = true;
    int64_t last = -1;
    for (const auto axis : op.getAxes()) {
      if (axis <= last) {
        canonicalized = false;
        break;
      }
      last = axis;
    }
    if (canonicalized) {
      // Already canonical: report "no match" so the rewrite driver stops
      // (returning success here would loop forever).
      return failure();
    }
    // canonicalize the reduce axes
    auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
    auto new_op = rewriter.create<ReduceOp>(
        op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
    rewriter.replaceOp(op, new_op);
    // NOLINTEND
    return success();
  }
};

// Expands to getCanonicalizationPatterns for a reduce op, registering the
// shared CanonicalizeReduceOp pattern above.
#define REDUCE_OP_SHAPE_CANONICALIZE(OP)                                       \
  void OP::getCanonicalizationPatterns(RewritePatternSet &results,             \
                                       MLIRContext *context) {                 \
    using CanonicalizeOp = CanonicalizeReduceOp<OP>;                           \
    results.add<CanonicalizeOp>(context);                                      \
  }

// Expands to the InferTensorTypeAdaptor hook for a reduce op. Axes are
// canonicalized (non-negative, sorted, unique) before the reduced shape is
// computed, so inference is insensitive to the attribute's surface form.
// Fix: the previous expansion declared an unused local
// `llvm::SmallVector<int64_t> outShape;` in every instantiation.
#define REDUCE_OP_SHAPE_INFER(OP)                                              \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType());     \
    auto rank = operandTy.getRank();                                           \
    ShapeAdaptor inputShape(operandTy);                                        \
    return InferReduceReturnTypes(                                             \
        inputShape, operandTy.getElementType(),                                \
        canonicalizeReduceAxes(adaptor.getAxes(), rank),                       \
        adaptor.getKeepDims(), inferredReturnShapes);                          \
  }

// Expands to verify() for a reduce op: the operand must be ranked and every
// canonicalized reduce axis must lie in [0, rank). Axes below -rank stay
// negative after canonicalization and are caught by the `axis < 0` check.
#define REDUCE_OP_VERIFY(OP)                                                   \
  LogicalResult OP::verify() {                                                 \
    auto operandTy = dyn_cast<ShapedType>(getOperand().getType());             \
    if (!operandTy.hasRank()) {                                                \
      return emitOpError("Invalid operand shape!\n");                          \
    }                                                                          \
    int64_t rank = operandTy.getRank();                                        \
    for (const auto axis : canonicalizeReduceAxes(getAxes(), rank)) {          \
      if (axis >= rank || axis < 0) {                                          \
        return emitOpError("Reduce axis not valid!\n");                        \
      }                                                                        \
    }                                                                          \
    return success();                                                          \
  }

// Bundles the canonicalizer, shape-inference, and verifier definitions for
// one reduce op.
#define REDUCE_OP_DEFINE(OP)                                                   \
  REDUCE_OP_SHAPE_CANONICALIZE(OP)                                             \
  REDUCE_OP_SHAPE_INFER(OP)                                                    \
  REDUCE_OP_VERIFY(OP)

// Instantiate all hooks for each concrete reduce op.
REDUCE_OP_DEFINE(onednn_graph::ReduceSumOp)
REDUCE_OP_DEFINE(onednn_graph::ReduceMeanOp)

//===----------------------------------------------------------------------===//
// Matmul ops shape infer
//===----------------------------------------------------------------------===//

LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
Expand All @@ -44,6 +186,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
const ShapeAdaptor &rhsShape, int64_t range,
int64_t diff, SmallVector<int64_t> &outDims) {
for (int64_t i = 0; i < range; i++) {
// TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
int64_t idx = i - diff;
if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
return failure();
Expand Down Expand Up @@ -134,7 +277,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
SmallVector<int64_t> resultShape;
if (!biasRankMatch ||
!OpTrait::util::getBroadcastedShape(
retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
resultShape)) {
return failure();
}
Expand Down
Loading