[Transform] Add basic onednn_graph dialect lowering #61

Merged · 5 commits · May 14, 2024

Conversation

LongshengDu
Contributor

@LongshengDu LongshengDu commented May 10, 2024

Added basic onednn_graph dialect lowering for matmul, relu, add.
Tracking: #28
Depending on: #43

Lowering of elementwise binary ops with a constant operand uses named ops, the same as the other elementwise ops, for consistency. We will continue to develop this approach and explore its strengths and limitations:

%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32>
%relued = linalg.max 
    ins(%biased, %cst : tensor<128x256xf32>, tensor<128x256xf32>) 
    outs(%output : tensor<128x256xf32>) -> tensor<128x256xf32>

We can explore more options in the future if named ops cannot accommodate our design, e.g.:

%c0f = arith.constant 0.0 : f32
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

@LongshengDu LongshengDu force-pushed the longsheng/add_onednn_graph branch from 4e629e5 to cd43162 Compare May 13, 2024 03:16
@LongshengDu LongshengDu force-pushed the longsheng/onednn_lower branch from 5ee1278 to c06f84d Compare May 13, 2024 08:04
Contributor

@kurapov-peter kurapov-peter left a comment


Just a couple of minor comments, otherwise looks fine. I'd probably also add some failure test cases (e.g., for unexpected broadcasting) as a nice-to-have; not necessary in this PR though.

Comment on lines 219 to 223
int64_t rank = resultTy.getRank();
SmallVector<int64_t> permutation(rank);
std::iota(std::begin(permutation), std::end(permutation), 0);
permutation[rank - 2] = rank - 1;
permutation[rank - 1] = rank - 2;
Contributor


Since this already assumes it's 2d, can it be static?

Contributor Author


Yeah, for a 2D matmul it can be static.
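The construction under discussion can be sketched standalone in plain C++ (the helper name is hypothetical; in the PR the permutation feeds a transpose of the matmul operand). It builds the identity permutation with `std::iota` and then swaps the last two dimensions, which for the 2D case is statically `{1, 0}`:

```cpp
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical standalone sketch of the snippet above: identity
// permutation [0, 1, ..., rank-1] with the last two dims swapped.
std::vector<int64_t> buildTransposePermutation(int64_t rank) {
  std::vector<int64_t> permutation(rank);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[rank - 2], permutation[rank - 1]);
  return permutation;
}
```

For `rank == 2` this always produces `{1, 0}`, which is why the reviewer's suggestion to hard-code the 2D case works.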


typedef Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);
Contributor


The definition is a bit confusing at first glance, since mlir::Value already exists (even though the result type is the same). If there's any better option, go for it. Otherwise it's fine.

Contributor Author


This is a typedef for a function pointer type whose functions return mlir::Value.
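For illustration, here is the typedef next to an equivalent C++11 `using` alias, with toy stand-in types (assumptions for this sketch only, since the MLIR classes aren't available standalone). The alias form puts the declared name on the left, which some readers find easier to parse:

```cpp
#include <type_traits>

// Toy stand-ins for mlir::Operation, mlir::PatternRewriter,
// mlir::TensorType, and mlir::Value (sketch-only assumptions).
struct Operation;
struct PatternRewriter;
struct TensorType {};
struct Value {};

// Original spelling: C-style function-pointer typedef.
typedef Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);

// Equivalent alias declaration; the new name appears on the left.
using GetOperandFnAlias = Value (*)(Operation *, PatternRewriter &, TensorType);

static_assert(std::is_same_v<GetOperandFn, GetOperandFnAlias>,
              "both spellings name the same function-pointer type");
```

Both spellings declare the same type; the choice is purely about readability.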

Comment on lines 75 to 101
struct OriginalOperand {
  template <unsigned I>
  static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
    if (I >= op->getNumOperands()) {
      op->emitError("Index exceeds operand num.\n");
      return nullptr;
    }
    return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
  }
};

struct ConstantOperand {
  template <int64_t I>
  static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
    const auto loc = op->getLoc();
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>( //
          loc, DenseElementsAttr::get(ty, int64_t(I)));
    } else if (llvm::isa<FloatType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>( //
          loc, DenseElementsAttr::get(ty, float(I)));
    } else {
      op->emitError("Not a supported element type for constant.\n");
      return nullptr;
    }
  }
};
Contributor


These methods are stateless. What are they wrapped in a struct for? Is it just for logical grouping, expecting more methods?

Contributor Author


Just for logical grouping; more methods will be added, e.g. getting a constant operand from attributes for ops like onednn_graph.pow.
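The grouping pattern being discussed can be sketched standalone with toy types (in the PR the getters take `Operation *`, `PatternRewriter &`, `TensorType` and return `Value`; everything below is a sketch-only assumption). The structs carry no state: they only namespace families of static getters whose instantiations all match one function-pointer signature, so a lowering pattern can mix original and constant operands freely:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-ins for the MLIR types used in the PR.
struct Op { std::vector<int64_t> operands; };

using GetOperandFn = int64_t (*)(const Op &);

// Stateless struct used purely as a logical grouping for related
// static getters; no instances are ever created.
struct OriginalOperand {
  template <unsigned I>
  static int64_t getIdx(const Op &op) {
    assert(I < op.operands.size() && "index exceeds operand num");
    return op.operands[I];
  }
};

struct ConstantOperand {
  template <int64_t I>
  static int64_t getConst(const Op &) { return I; }
};

// Getters from both groups share one signature, so a pattern can take
// any mix of them, e.g. {operand 0, constant 0} for a ReLU-style max.
int64_t applyBinary(const Op &op, GetOperandFn lhs, GetOperandFn rhs) {
  return std::max(lhs(op), rhs(op));
}
```

Under this sketch, `applyBinary(op, &OriginalOperand::getIdx<0>, &ConstantOperand::getConst<0>)` computes `max(x, 0)`, mirroring how the relu lowering picks its two operands.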

Contributor

@kurapov-peter kurapov-peter left a comment


Looks like a reasonable starting point now

@kurapov-peter kurapov-peter merged commit eddff2e into longsheng/add_onednn_graph May 14, 2024
LongshengDu pushed a commit that referenced this pull request May 15, 2024
[Transform] Add basic onednn_graph dialect lowering (#61)