Skip to content

[OneDNN Graph Dialect] Use Broadcast Trait and organize data types #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/Traits.h"

#define GET_OP_CLASSES
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc"
Expand Down
25 changes: 13 additions & 12 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
Op<OneDNNGraphDialect, mnemonic, traits>;

class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType]> {
let arguments = (ins OneDNNGraph_LogicalTensor:$input_0,
OneDNNGraph_LogicalTensor:$input_1);
let results = (outs OneDNNGraph_LogicalTensor:$result);
OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType,
ResultsBroadcastableShape]> {
let arguments = (ins OneDNNGraph_FloatTensor:$input_0,
OneDNNGraph_FloatTensor:$input_1);
let results = (outs OneDNNGraph_FloatTensor:$result);

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
}

class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins OneDNNGraph_LogicalTensor:$operand);
let results = (outs OneDNNGraph_LogicalTensor:$result);
let arguments = (ins OneDNNGraph_FloatTensor:$operand);
let results = (outs OneDNNGraph_FloatTensor:$result);

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
Expand All @@ -51,15 +52,15 @@ def OneDNNGraph_MatMulOp :
OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
let summary = "Generalized matrix multiplication";
let description = [{
`https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html`
`https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
}];

let arguments = (ins OneDNNGraph_LogicalTensor:$input_a,
OneDNNGraph_LogicalTensor:$input_b,
let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
OneDNNGraph_FloatTensor:$input_b,
Optional<OneDNNGraph_LogicalTensor>:$bias,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
let results = (outs OneDNNGraph_LogicalTensor:$result);
let results = (outs OneDNNGraph_FloatTensor:$result);

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
Expand All @@ -68,14 +69,14 @@ def OneDNNGraph_MatMulOp :
def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
let summary = "element-wise relu";
let description = [{
`https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html`
`https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
}];
}

def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
let summary = "element-wise addition with multi-directional broadcast";
let description = [{
`https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html`
`https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
}];
}

Expand Down
24 changes: 18 additions & 6 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,26 @@ include "OneDNNGraphDialect.td"
// OneDNNGraph type definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
def OneDNNGraph_Float : AnyTypeOf<[F32,
F16,
BF16]>;

//===----------------------------------------------------------------------===//
// Integer types.
//===----------------------------------------------------------------------===//

def OneDNNGraph_Int : AnyTypeOf<[SI<8>,
UI<8>]>;
Comment on lines +21 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the rationale for having separate types? Also, why do we only have fp16 and int8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ops like matmul, only f32/bf16/f16 are supported: https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html#supported-data-types


def OneDNNGraph_DataType : AnyTypeOf<[
F16,
BF16,
F32,
SI<32>,
SI<8>,
UI<8>]>;
OneDNNGraph_Float,
OneDNNGraph_Int
]>;

def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>;
def OneDNNGraph_FloatTensor : TensorOf<[OneDNNGraph_Float]>;

#endif // ONEDNNGRAPH_TYPES
66 changes: 14 additions & 52 deletions lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,22 @@
namespace mlir {
namespace onednn_graph {

// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
template <typename ShapeRange>
static LogicalResult inferBroadcastShape(
ShapeRange operands, SmallVector<int64_t> &outShape,
const std::function<ShapeAdaptor(ShapeRange, size_t)> &getShapeIdx) {
int64_t outRank = 0;
for (size_t i = 0; i < operands.size(); i++) {
auto shape = getShapeIdx(operands, i);
if (!shape.hasRank()) {
return failure();
}
outRank = std::max(outRank, shape.getRank());
}
// Start with all 1 dim
outShape.clear();
outShape.resize(outRank, 1);
// Scan each shape for match dims
for (size_t i = 0; i < operands.size(); i++) {
auto shape = getShapeIdx(operands, i);
auto diff = outShape.size() - shape.getRank();
for (int64_t j = 0; j < shape.getRank(); j++) {
auto dim1 = outShape[diff + j];
auto dim2 = shape.getDimSize(j);
auto resolvedDim = dim1;

if (dim1 == 1) {
resolvedDim = dim2;
} else if (dim2 == 1) {
resolvedDim = dim1;
} else if (dim1 != dim2) {
return failure();
}
outShape[diff + j] = resolvedDim;
}
}
return success();
}

LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outShape;
auto resultTy = dyn_cast<TensorType>(operands.front().getType());
auto getShapeIdx = [](ValueShapeRange operands, size_t i) {
return operands.getShape(i);
auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
auto getShapeIdx = [&operands](size_t i) {
return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
};
auto ret =
inferBroadcastShape<ValueShapeRange>(operands, outShape, getShapeIdx);

auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
outShape);
inferredReturnShapes.push_back(
ShapedTypeComponents(outShape, resultTy.getElementType()));
return ret;
return LogicalResult::success(ret);
}

LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
Expand Down Expand Up @@ -158,22 +121,21 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
// Not supported
return failure();
}
auto getShapeIdx = [](ArrayRef<ShapeAdaptor> operands, size_t i) {
return operands[i];
};
// final shape
auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType());
inferredReturnShapes.push_back(retShape);
// check for bias broadcasting
if (adaptor.getBias()) {
ShapeAdaptor biasShape(adaptor.getBias().getType());
ShapeAdaptor matShape(retShape);
auto biasType = adaptor.getBias().getType();
ShapeAdaptor biasShape(biasType);

bool biasRankMatch = biasShape.getRank() == 1 ||
biasShape.getRank() == (int64_t)outShape.size();
SmallVector<int64_t> bcastShape;
SmallVector<int64_t> resultShape;
if (!biasRankMatch ||
failed(inferBroadcastShape<ArrayRef<ShapeAdaptor>>(
{matShape, biasShape}, bcastShape, getShapeIdx))) {
!OpTrait::util::getBroadcastedShape(
retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
resultShape)) {
return failure();
}
}
Expand Down