Skip to content

Commit b361e4d

Browse files
committed
Implement LinalgGroupedConvolutionOpInterface to unify grouped convs
1 parent 28dd55b commit b361e4d

File tree

10 files changed

+614
-6
lines changed

10 files changed

+614
-6
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace mlir {
2828
namespace linalg {
2929
class IteratorTypeAttr;
3030
class LinalgOp;
31+
class ConvolutionOpInterface;
32+
class GroupedConvolutionOpInterface;
3133
class GenericOp;
3234

3335
namespace detail {
@@ -133,6 +135,38 @@ std::optional<Value> isaFillOpInterface(GenericOp genericOp);
133135

134136
namespace detail {
135137

138+
// Common implementations for ConvolutionOpInterface.
// Free functions shared by convolution-like ops. In this change,
// `GroupedConvNDOp` forwards its print/parse/region-builder hooks here and
// the grouped-convolution interface uses the stride/dilation accessors.
139+
namespace convolution_impl {
140+
// Returns the strides of `op` as a vector. Backs the default implementation
// of `GroupedConvolutionOpInterface::getStridesVector`.
141+
SmallVector<int64_t, 2> getStrides(ConvolutionOpInterface op);
142+
// Returns the dilations of `op` as a vector. Backs the default implementation
// of `GroupedConvolutionOpInterface::getDilationsVector`.
143+
SmallVector<int64_t, 2> getDilations(ConvolutionOpInterface op);
144+
// Region builder for basic (non-quantized) convolution. Returned by
// `GroupedConvNDOp::getRegionBuilder`.
145+
void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
146+
ArrayRef<NamedAttribute> attrs);
147+
// Region builder for basic quantized convolution.
148+
void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
149+
ArrayRef<NamedAttribute> attrs);
150+
// Appends memory-effect instances for `op` to `effects`.
// NOTE(review): which effects are reported is decided by the definition,
// which is not visible here.
void getEffects(
151+
Operation *op,
152+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
153+
&effects);
154+
// Parses the custom assembly of a convolution op into `result`.
// `isQuantized` defaults to false.
// NOTE(review): presumably selects `quantizedRegionBuilder` when true;
// confirm against the definition.
155+
ParseResult parse(OpAsmParser &parser, OperationState &result,
bool isQuantized = false);
156+
// Prints `op` to `p`. NOTE(review): presumably the counterpart of `parse`.
void print(LinalgOp op, OpAsmPrinter &p);
157+
} // namespace convolution_impl
158+
159+
// Common implementations for GroupedConvolutionOpInterface.
160+
namespace grouped_convolution_impl {
161+
// Returns the spatial rank of `op`. `GroupedConvNDOp::getSpatialRank`
// forwards to this helper.
int64_t getSpatialRank(GroupedConvolutionOpInterface op);
162+
// Builds the indexing maps, returned as an `ArrayAttr`, from the number of
// spatial dimensions, the operand `layouts`, and the `strides`/`dilations`.
// NOTE(review): presumably `layouts` holds one entry per operand and
// `strides`/`dilations` one entry per spatial dimension; confirm against
// the definition.
ArrayAttr createCommonIndexingMaps(
163+
MLIRContext *ctx, int64_t numSpatial,
164+
const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
165+
const SmallVectorImpl<int64_t> &strides,
166+
const SmallVectorImpl<int64_t> &dilations);
167+
// Returns the iterator types of `op`. Backs the default implementation of
// `GroupedConvolutionOpInterface::getIteratorTypes`.
ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
168+
} // namespace grouped_convolution_impl
169+
136170
/// Returns true if the block contains a contraction of the following form:
137171
///
138172
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -189,6 +223,9 @@ LogicalResult verifyContractionInterface(Operation *op);
189223
/// Verify that `op` conforms to the ConvolutionOpInterface.
190224
LogicalResult verifyConvolutionInterface(Operation *op);
191225

226+
/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
227+
LogicalResult verifyGroupedConvolutionInterface(Operation *op);
228+
192229
/// Verify that `op` conforms to the FillOpInterface.
193230
LogicalResult verifyFillInterface(Operation *op);
194231

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,101 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
175175
return $_op.getOperation()->getOperand(1);
176176
}]
177177
>,
178+
InterfaceMethod<
179+
/*desc=*/"Return the spatial rank.",
180+
/*retTy=*/"int64_t",
181+
/*methodName=*/"getSpatialRank",
182+
/*args=*/(ins),
183+
/*methodBody=*/"",
184+
/*defaultImplementation=*/[{
185+
// Most convolutions' inputs have batch, channel and spatial dims, so the spatial rank is the image rank minus two.
186+
return cast<ShapedType>(image().getType()).getRank() - 2;
187+
}]
188+
>
189+
];
190+
}
191+
192+
def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
193+
LinalgConvolutionOpInterface]> {
194+
let description = [{
195+
A grouped convolution is defined in general terms:
196+
1. It is a convolution as defined by `ConvolutionOpInterface`.
197+
2. Operands have the following distinct dimensions (excluding batch in input/output): group, channel, spatial.
198+
3. `input_rank == kernel_rank == output_rank` (including batch in input/output)
199+
4. Reductions are along the input channel and spatial dimensions while group, output channel
200+
and output spatial dimensions are parallel.
201+
}];
202+
let cppNamespace = "::mlir::linalg";
203+
let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
204+
let methods = [
205+
InterfaceMethod<[{
206+
Returns the layout of every operand as a list of `GroupedConvDim` enums.
207+
}],
208+
"SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
209+
>,
210+
InterfaceMethod<[{
211+
Returns the groups position for the input.
212+
}],
213+
"int64_t", "getInputGroupsPosition", (ins)
214+
>,
215+
InterfaceMethod<[{
216+
Returns the channel position for the input.
217+
}],
218+
"int64_t", "getInputChannelPosition", (ins)
219+
>,
220+
InterfaceMethod<[{
221+
Returns the channel position for the output.
222+
}],
223+
"int64_t", "getOutputChannelPosition", (ins)
224+
>,
225+
InterfaceMethod<[{
    Returns the number of groups, i.e. the extent of the input (image)
    operand along its group dimension.
  }],
  "int64_t", "getNumGroups", (ins),
  /*methodBody=*/[{}],
  /*defaultImplementation=*/[{
    // Index with the group position itself, exactly as `getNumInputChannels`
    // indexes with the channel position. Subtracting one would read the
    // dimension that precedes the group dimension instead.
    return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition()];
  }]>,
233+
InterfaceMethod<[{
234+
Get number of input channels.
235+
}],
236+
"int64_t", "getNumInputChannels", (ins),
237+
/*methodBody=*/[{}],
238+
/*defaultImplementation=*/[{
239+
return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
240+
}]>,
241+
InterfaceMethod<[{
242+
Get number of output channels.
243+
}],
244+
"int64_t", "getNumOutputChannels", (ins),
245+
/*methodBody=*/[{}],
246+
/*defaultImplementation=*/[{
247+
return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
248+
}]>,
249+
InterfaceMethod<[{
250+
Returns the iterator types of the operation.
251+
}],
252+
"::mlir::ArrayAttr", "getIteratorTypes", (ins),
253+
/*methodBody=*/[{}],
254+
/*defaultImplementation=*/[{
255+
return detail::grouped_convolution_impl::getIteratorTypes($_op);
256+
}]>,
257+
InterfaceMethod<[{
258+
Returns strides.
259+
}],
260+
"::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
261+
/*methodBody=*/[{}],
262+
/*defaultImplementation=*/[{
263+
return detail::convolution_impl::getStrides($_op);
264+
}]>,
265+
InterfaceMethod<[{
266+
Returns dilations.
267+
}],
268+
"::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
269+
/*methodBody=*/[{}],
270+
/*defaultImplementation=*/[{
271+
return detail::convolution_impl::getDilations($_op);
272+
}]>
178273
];
179274
}
180275

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,147 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
384384
let hasVerifier = 1;
385385
}
386386

387+
//===----------------------------------------------------------------------===//
388+
// GroupedConvNDOp ops.
389+
//===----------------------------------------------------------------------===//
390+
391+
def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
392+
[AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
393+
394+
let summary = [{
395+
Performs N-D grouped convolution with switchable channel position; either first or last.
396+
}];
397+
let description = [{
398+
Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
399+
will represent all spatial dimensions. Operand layouts are determined by the `layouts`
400+
`StrArrayAttr` attribute. Each element of the array is a string representing the layout of the
401+
corresponding operand; each of its characters should be mappable to a `GroupedConvDim` enum case, i.e. one of
402+
n: (batch dim)
403+
g: (group dim)
404+
f: (feature or output channel dim)
405+
s: (all spatial dims)
406+
c: (input channel dim).
407+
408+
The domain will always be in the order `(N, G, F, S, C, KS)`.
409+
410+
}];
411+
412+
let arguments = (ins
413+
Variadic<TensorOrMemref>:$inputs,
414+
Variadic<TensorOrMemref>:$inits,
415+
DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
416+
OptionalAttr<I64ElementsAttr>:$strides,
417+
OptionalAttr<I64ElementsAttr>:$dilations
418+
);
419+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
420+
let regions = (region AnyRegion:$region);
421+
422+
let skipDefaultBuilders = 1;
423+
let builders = [
424+
OpBuilder<
  (ins "Value":$input, "Value":$filter, "Value":$init,
       CArg<"ArrayRef<int64_t>", "{}">:$strides,
       CArg<"ArrayRef<int64_t>", "{}">:$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
  [{
    // Builds the op without explicit result types. Empty `strides` or
    // `dilations` default to all ones, one entry per spatial dimension.
    // The input is taken to have three non-spatial dimensions, hence
    // `rank - 3`.
    int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;

    // `ArrayRef` is a non-owning view, so the default values must live in a
    // named local that outlives the `DenseElementsAttr::get` calls below.
    // Assigning a temporary `SmallVector` directly to the `ArrayRef` would
    // leave it pointing at destroyed storage.
    ::llvm::SmallVector<int64_t, 2> unitValues(numSpatialDims, 1);
    if (strides.empty())
      strides = unitValues;
    if (dilations.empty())
      dilations = unitValues;

    $_state.addAttribute(getStridesAttrName($_state.name),
      ::mlir::DenseElementsAttr::get(
        ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
    $_state.addAttribute(getDilationsAttrName($_state.name),
      ::mlir::DenseElementsAttr::get(
        ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
    buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
                      attributes, GroupedConvNDOp::getRegionBuilder());
  }]>,
443+
OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
       "Value":$init,
       CArg<"ArrayRef<int64_t>", "{}">:$strides,
       CArg<"ArrayRef<int64_t>", "{}">:$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
  [{
    // Builds the op with explicit result tensor types. Empty `strides` or
    // `dilations` default to all ones, one entry per spatial dimension.
    // The input is taken to have three non-spatial dimensions, hence
    // `rank - 3`.
    int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;

    // `ArrayRef` is a non-owning view, so the default values must live in a
    // named local that outlives the `DenseElementsAttr::get` calls below.
    // Assigning a temporary `SmallVector` directly to the `ArrayRef` would
    // leave it pointing at destroyed storage.
    ::llvm::SmallVector<int64_t, 2> unitValues(numSpatialDims, 1);
    if (strides.empty())
      strides = unitValues;
    if (dilations.empty())
      dilations = unitValues;

    $_state.addAttribute(getStridesAttrName($_state.name),
      ::mlir::DenseElementsAttr::get(
        ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
    $_state.addAttribute(getDilationsAttrName($_state.name),
      ::mlir::DenseElementsAttr::get(
        ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
    buildStructuredOp($_builder, $_state, resultTensorTypes,
                      {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
  }]>,
463+
OpBuilder<
464+
(ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
465+
"Value":$init, "Attribute":$strides, "Attribute":$dilations,
466+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
467+
[{
468+
// Strides and dilations arrive as ready-made attributes and are attached
// unchanged. Unlike the `ArrayRef<int64_t>` builders, no defaults are
// filled in and nothing is derived from the input rank.
$_state.addAttribute(getStridesAttrName($_state.name), strides);
469+
$_state.addAttribute(getDilationsAttrName($_state.name), dilations);
470+
buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
471+
attributes, GroupedConvNDOp::getRegionBuilder());
472+
}]>
473+
];
474+
475+
// TODO: Figure out how to move this to the interface
476+
let extraClassDeclaration = structuredOpsBaseDecls # [{
477+
// Prints this op by forwarding to the printer shared by convolution ops
// (`detail::convolution_impl::print`).
void print(::mlir::OpAsmPrinter &printer) {
478+
return detail::convolution_impl::print(*this, printer);
479+
}
480+
// Parses this op by forwarding to the parser shared by convolution ops.
// The helper's `isQuantized` argument is left at its default (false).
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
481+
::mlir::OperationState &result) {
482+
return detail::convolution_impl::parse(parser, result);
483+
}
484+
// Returns the callback used to populate the op's region: the shared
// non-quantized convolution region builder.
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
485+
mlir::ArrayRef<mlir::NamedAttribute>)>
486+
getRegionBuilder() {
487+
return detail::convolution_impl::regionBuilder;
488+
}
489+
// Implement functions necessary for DestinationStyleOpInterface.
490+
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
491+
492+
// Implement functions necessary for LinalgOp.
493+
ArrayAttr getIndexingMaps();
494+
495+
// Implement functions necessary for GroupedConvolutionOpInterface
496+
// Overrides the default `getSpatialRank` of the convolution interface
// (image rank minus two) by delegating to the grouped-convolution helper.
int64_t getSpatialRank() {
497+
return detail::grouped_convolution_impl::getSpatialRank(*this);
498+
}
499+
500+
// Converts the `layouts` string-array attribute into enum form: one inner
// vector per layout string, one `GroupedConvDim` per character of that
// string, in order.
//
// Every character must name a `GroupedConvDim` case; this is checked with an
// assertion only.
SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
  SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
  for (::mlir::Attribute attr : getLayoutsAttr().getValue()) {
    // View the attribute's string storage directly. No std::string copy is
    // needed for the layout or for each single-character slice.
    ::llvm::StringRef layoutStr = cast<StringAttr>(attr).getValue();
    SmallVector<::mlir::utils::GroupedConvDim> layout;
    layout.reserve(layoutStr.size());
    for (size_t i = 0, e = layoutStr.size(); i < e; ++i) {
      auto maybeDimEnum =
          ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1));
      assert(maybeDimEnum &&
             "layout character does not name a GroupedConvDim case");
      layout.push_back(maybeDimEnum.value());
    }
    // The inner vector is not used again, so move it instead of copying.
    layouts.push_back(std::move(layout));
  }
  return layouts;
}
514+
515+
// Position of the output-channel (`f`) dimension in the output operand.
// Hard-coded to 2, which matches the default output layout "ngfs"; the
// `layouts` attribute is not consulted here.
int64_t getOutputChannelPosition() {
516+
return 2;
517+
}
518+
519+
// Position of the input-channel (`c`) dimension in the input operand.
// Hard-coded to 2, which matches the default input layout "ngcs"; the
// `layouts` attribute is not consulted here.
int64_t getInputChannelPosition() {
520+
return 2;
521+
}
522+
523+
// Position of the group (`g`) dimension in the input operand.
// Hard-coded to 1, which matches the default input layout "ngcs"; the
// `layouts` attribute is not consulted here.
int64_t getInputGroupsPosition() {
524+
return 1;
525+
}
526+
}];
527+
}
387528

388529
//===----------------------------------------------------------------------===//
389530
// Transpose op.

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
2020
let cppNamespace = "::mlir::utils";
2121
}
2222

23+
// Kinds of dimension that may appear in a grouped-convolution operand layout.
// Case names are single characters so that a layout string such as "ngcs"
// can be symbolized one character at a time (see
// `GroupedConvNDOp::getLayoutsEnums`).
def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
24+
[
25+
I32EnumAttrCase<"n", 0>, // batch
26+
I32EnumAttrCase<"g", 1>, // group
27+
I32EnumAttrCase<"f", 2>, // feature (output channel)
28+
I32EnumAttrCase<"s", 3>, // spatial (stands for all spatial dims)
29+
I32EnumAttrCase<"c", 4> // channel (input channel)
30+
]> {
31+
// Do not generate a specialized attribute class for this enum.
let genSpecializedAttr = 0;
32+
let cppNamespace = "::mlir::utils";
33+
}
34+
2335
#endif // STRUCTURED_OPS_UTILS

0 commit comments

Comments
 (0)