Skip to content

[mlir][linalg] Implement LinalgGroupedConvolutionOpInterface to unify grouped convs #94796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
class ConvolutionOpInterface;
class GroupedConvolutionOpInterface;
class GenericOp;

namespace detail {
Expand Down Expand Up @@ -147,6 +149,34 @@ std::optional<Value> isaFillOpInterface(GenericOp genericOp);

namespace detail {

// Common implementations shared by ops implementing ConvolutionOpInterface.
namespace convolution_impl {
// Returns the op's strides as a vector (read from its strides attribute).
SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
// Returns the op's dilations as a vector (read from its dilations attribute).
SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
// Region builder for basic convolution: multiply-accumulate body.
void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs);
// Region builder for basic quantized convolution (zero-point adjusted
// multiply-accumulate body).
void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs);
// Shared custom-assembly parser for convolution ops; `isQuantized` selects
// the quantized region builder.
ParseResult parse(OpAsmParser &parser, OperationState &result,
bool isQuantized = false);
// Shared custom-assembly printer for convolution ops.
void print(LinalgOp op, OpAsmPrinter &p);
} // namespace convolution_impl

// Common implementations shared by ops implementing
// GroupedConvolutionOpInterface.
namespace grouped_convolution_impl {
// Returns the number of spatial dimensions of the grouped convolution.
int64_t getSpatialRank(GroupedConvolutionOpInterface op);
// Builds the indexing maps shared by grouped convolutions: one map per
// operand, derived from each operand's dimension layout plus the given
// strides and dilations applied to the spatial dimensions.
ArrayAttr createCommonIndexingMaps(
MLIRContext *ctx, int64_t numSpatial,
const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
const SmallVectorImpl<int64_t> &strides,
const SmallVectorImpl<int64_t> &dilations);
// Returns the iterator types (parallel/reduction) for a grouped convolution.
ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
} // namespace grouped_convolution_impl

/// Returns true if the block contains a contraction of the following form:
///
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
Expand Down Expand Up @@ -206,6 +236,9 @@ LogicalResult verifyContractionInterface(Operation *op);
/// Verify that `op` conforms to the ConvolutionOpInterface.
/// Returns failure if the op does not satisfy the interface's structural
/// requirements.
LogicalResult verifyConvolutionInterface(Operation *op);

/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
/// Returns failure if the op does not satisfy the interface's structural
/// requirements.
LogicalResult verifyGroupedConvolutionInterface(Operation *op);

/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);

Expand Down
80 changes: 80 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,87 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
return $_op.getOperation()->getOperand(1);
}]
>,
InterfaceMethod<
/*desc=*/"Return the number of spatial dimensions of the convolution.",
/*retTy=*/"int64_t",
/*methodName=*/"getSpatialRank",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Default assumes the input has exactly two non-spatial dimensions
// (batch and channel) — ops with other layouts should override this.
return cast<ShapedType>(image().getType()).getRank() - 2;
}]
>
];
}

def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
    LinalgConvolutionOpInterface]> {
  let description = [{
    A grouped convolution is defined in general terms:
    1. It is a convolution as defined by `ConvolutionOpInterface`.
    2. Operands have the following distinct dimensions (excluding batch in
       input/output): group, channel, spatial.
    3. `input_rank == kernel_rank == output_rank` (including batch in
       input/output).
    4. Reductions are along the input channel and spatial dimensions while
       group, output channel and output spatial dimensions are parallel.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
  let methods = [
    InterfaceMethod<[{
      Returns the layouts of each operand (image, kernel, init). Each layout
      is represented by a vector of `GroupedConvDim`s.
    }],
    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getOperandConvDims", (ins)
    >,
    InterfaceMethod<[{
      Returns the group dimension position for the input.
    }],
    "int64_t", "getInputGroupsPosition", (ins)
    >,
    InterfaceMethod<[{
      Returns the channel dimension position for the input.
    }],
    "int64_t", "getInputChannelPosition", (ins)
    >,
    InterfaceMethod<[{
      Returns the channel dimension position for the output.
    }],
    "int64_t", "getOutputChannelPosition", (ins)
    >,
  ];

  let extraSharedClassDeclaration = [{
    // Get the number of groups: the extent of the input's group dimension.
    int64_t getNumGroups() {
      return cast<ShapedType>(
          cast<::mlir::linalg::ConvolutionOpInterface>(
              $_op.getOperation()).image().getType())
          .getShape()[$_op.getInputGroupsPosition()];
    }
    // Get the number of input channels: the extent of the input's channel
    // dimension.
    int64_t getNumInputChannels() {
      return cast<ShapedType>(
          cast<::mlir::linalg::ConvolutionOpInterface>(
              $_op.getOperation()).image().getType())
          .getShape()[$_op.getInputChannelPosition()];
    }
    // Get the number of output channels: the extent of the init/output
    // operand's channel dimension.
    int64_t getNumOutputChannels() {
      return cast<ShapedType>($_op->getOperand(2).getType())
          .getShape()[$_op.getOutputChannelPosition()];
    }
    // Returns iterator types.
    ::mlir::ArrayAttr getIteratorTypes() {
      return detail::grouped_convolution_impl::getIteratorTypes($_op);
    }
    // Returns strides as a vector.
    ::llvm::SmallVector<int64_t, 2> getStridesVector() {
      return detail::convolution_impl::getStrides($_op);
    }
    // Returns dilations as a vector.
    ::llvm::SmallVector<int64_t, 2> getDilationsVector() {
      return detail::convolution_impl::getDilations($_op);
    }
  }];
}

def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
Expand Down
163 changes: 163 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,169 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// GroupedConvNDOp ops.
//===----------------------------------------------------------------------===//

def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
    [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {

  let summary = [{
    Performs N-D grouped convolution with parametrizable operand layouts.
  }];
  let description = [{
    Allows any number of spatial dimensions but treats all of them as
    contiguous. Throughout, `S` will represent all spatial dimensions.
    Operand layouts are determined by the `layouts` `StrArrayAttr` attribute.
    Each element of the array is a string representing the layout of the
    corresponding operand and should be mappable to a `GroupedConvDim` enum,
    i.e. one of
      n: (batch dim)
      g: (group dim)
      f: (feature or output channel dim)
      s: (all spatial dims)
      c: (input channel dim).

    The domain will always be in the order `(N, G, F, S, C, KS)`.

  }];

  let arguments = (ins
    Variadic<TensorOrMemref>:$inputs,
    Variadic<TensorOrMemref>:$inits,
    DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
    OptionalAttr<I64ElementsAttr>:$strides,
    OptionalAttr<I64ElementsAttr>:$dilations
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    // Inferred-result-type builder; strides/dilations default to all-ones.
    OpBuilder<
    (ins "Value":$input, "Value":$filter, "Value":$init,
      CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      // Input rank is batch + group + channel + spatial, so spatial rank is
      // rank - 3.
      int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
      if (strides.empty())
        strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
      if (dilations.empty())
        dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
      $_state.addAttribute(getStridesAttrName($_state.name),
        ::mlir::DenseElementsAttr::get(
          ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
      $_state.addAttribute(getDilationsAttrName($_state.name),
        ::mlir::DenseElementsAttr::get(
          ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
      buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
        attributes, GroupedConvNDOp::getRegionBuilder());
    }]>,
    // Explicit-result-type builder; strides/dilations default to all-ones.
    OpBuilder<
    (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
      "Value":$init,
      CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      // Input rank is batch + group + channel + spatial, so spatial rank is
      // rank - 3.
      int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
      if (strides.empty())
        strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
      if (dilations.empty())
        dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
      $_state.addAttribute(getStridesAttrName($_state.name),
        ::mlir::DenseElementsAttr::get(
          ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
      $_state.addAttribute(getDilationsAttrName($_state.name),
        ::mlir::DenseElementsAttr::get(
          ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
      buildStructuredOp($_builder, $_state, resultTensorTypes,
        {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
    }]>,
    // Builder taking strides/dilations as already-built attributes.
    OpBuilder<
    (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
      "Value":$init, "Attribute":$strides, "Attribute":$dilations,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
    [{
      $_state.addAttribute(getStridesAttrName($_state.name), strides);
      $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
      buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
        attributes, GroupedConvNDOp::getRegionBuilder());
    }]>
  ];

  // TODO: Figure out how to move this to the interface
  let extraClassDeclaration = structuredOpsBaseDecls # [{
    void print(::mlir::OpAsmPrinter &printer) {
      return detail::convolution_impl::print(*this, printer);
    }
    static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
        ::mlir::OperationState &result) {
      return detail::convolution_impl::parse(parser, result);
    }
    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
        mlir::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return detail::convolution_impl::regionBuilder;
    }
    // Implement functions necessary for DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }

    // Implement functions necessary for LinalgOp. The indexing maps are
    // memoized on the op as an attribute to avoid rebuilding them.
    ::mlir::ArrayAttr getIndexingMaps() {
      ::mlir::ArrayAttr cached = (*this)->getAttrOfType<::mlir::ArrayAttr>(
          LinalgDialect::kMemoizedIndexingMapsAttrName);
      if (cached)
        return cached;

      cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
          getContext(), getSpatialRank(), getOperandConvDims(), getStridesVector(),
          getDilationsVector());

      (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
      return cached;
    }

    // Implement functions necessary for GroupedConvolutionOpInterface.
    int64_t getSpatialRank() {
      return detail::grouped_convolution_impl::getSpatialRank(*this);
    }

    // Decode the `layouts` attribute into one GroupedConvDim vector per
    // operand. Works on StringRefs directly to avoid copying each layout
    // string.
    SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getOperandConvDims() {
      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
      for (auto attr : getLayoutsAttr().getValue()) {
        ::llvm::StringRef layoutStr = cast<StringAttr>(attr).getValue();
        SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
        for (size_t i = 0, e = layoutStr.size(); i < e; ++i) {
          auto maybeDimEnum =
              ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1));
          assert(maybeDimEnum && "unexpected layout dimension character");
          layout[i] = maybeDimEnum.value();
        }
        layouts.push_back(std::move(layout));
      }
      return layouts;
    }

    // Position of 'f' (output channel) in the init/output layout string.
    int64_t getOutputChannelPosition() {
      size_t pos =
          cast<StringAttr>(getLayoutsAttr().getValue()[2]).getValue().find('f');
      assert(pos != ::llvm::StringRef::npos && "output layout must contain 'f'");
      return pos;
    }

    // Position of 'c' (input channel) in the input layout string.
    int64_t getInputChannelPosition() {
      size_t pos =
          cast<StringAttr>(getLayoutsAttr().getValue()[0]).getValue().find('c');
      assert(pos != ::llvm::StringRef::npos && "input layout must contain 'c'");
      return pos;
    }

    // Position of 'g' (group) in the input layout string.
    int64_t getInputGroupsPosition() {
      size_t pos =
          cast<StringAttr>(getLayoutsAttr().getValue()[0]).getValue().find('g');
      assert(pos != ::llvm::StringRef::npos && "input layout must contain 'g'");
      return pos;
    }
  }];
}

//===----------------------------------------------------------------------===//
// Transpose op.
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
let cppNamespace = "::mlir::utils";
}

// Dimension roles that may appear in a grouped-convolution operand layout.
def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim", [
      I32EnumAttrCase<"n", 0>, // batch dimension
      I32EnumAttrCase<"g", 1>, // group dimension
      I32EnumAttrCase<"f", 2>, // feature (output channel) dimension
      I32EnumAttrCase<"s", 3>, // spatial dimensions
      I32EnumAttrCase<"c", 4>  // channel (input channel) dimension
    ]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::utils";
}

#endif // STRUCTURED_OPS_UTILS
Loading