Skip to content

Commit f0678a4

Browse files
committed
small refactor
1 parent b361e4d commit f0678a4

File tree

4 files changed

+64
-71
lines changed

4 files changed

+64
-71
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,10 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
203203
let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
204204
let methods = [
205205
InterfaceMethod<[{
206-
Returns the groups position for the input.
206+
Returns the layouts of each operand (image, kernel, init). Each layout is represented
207+
by a vector of `GroupedConvDim`s.
207208
}],
208-
"SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
209+
"SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getOperandConvDims", (ins)
209210
>,
210211
InterfaceMethod<[{
211212
Returns the groups position for the input.
@@ -222,55 +223,39 @@ def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInter
222223
}],
223224
"int64_t", "getOutputChannelPosition", (ins)
224225
>,
225-
InterfaceMethod<[{
226-
Get number of groups.
227-
}],
228-
"int64_t", "getNumGroups", (ins),
229-
/*methodBody=*/[{}],
230-
/*defaultImplementation=*/[{
231-
return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition() - 1];
232-
}]>,
233-
InterfaceMethod<[{
234-
Get number of input channels.
235-
}],
236-
"int64_t", "getNumInputChannels", (ins),
237-
/*methodBody=*/[{}],
238-
/*defaultImplementation=*/[{
239-
return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
240-
}]>,
241-
InterfaceMethod<[{
242-
Get number of output channels.
243-
}],
244-
"int64_t", "getNumOutputChannels", (ins),
245-
/*methodBody=*/[{}],
246-
/*defaultImplementation=*/[{
247-
return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
248-
}]>,
249-
InterfaceMethod<[{
250-
Returns indexing maps for any spatial dimension.
251-
}],
252-
"::mlir::ArrayAttr", "getIteratorTypes", (ins),
253-
/*methodBody=*/[{}],
254-
/*defaultImplementation=*/[{
226+
];
227+
228+
let extraSharedClassDeclaration = [{
229+
// Get number of groups.
230+
int64_t getNumGroups() {
231+
return cast<ShapedType>(
232+
cast<::mlir::linalg::ConvolutionOpInterface>(
233+
$_op.getOperation()).image().getType())
234+
.getShape()[$_op.getInputGroupsPosition()];
235+
}
236+
// Get number of input channels.
237+
int64_t getNumInputChannels() {
238+
return cast<ShapedType>(
239+
cast<::mlir::linalg::ConvolutionOpInterface>(
240+
$_op.getOperation()).image().getType()).getShape()[$_op.getInputChannelPosition()];
241+
}
242+
// Get number of output channels.
243+
int64_t getNumOutputChannels() {
244+
return cast<ShapedType>($_op->getOperand(2).getType()).getShape()[$_op.getOutputChannelPosition()];
245+
}
246+
// Returns iterator tyes.
247+
::mlir::ArrayAttr getIteratorTypes() {
255248
return detail::grouped_convolution_impl::getIteratorTypes($_op);
256-
}]>,
257-
InterfaceMethod<[{
258-
Returns strides.
259-
}],
260-
"::llvm::SmallVector<int64_t, 2>", "getStridesVector", (ins),
261-
/*methodBody=*/[{}],
262-
/*defaultImplementation=*/[{
249+
}
250+
// Returns strides.
251+
::llvm::SmallVector<int64_t, 2> getStridesVector() {
263252
return detail::convolution_impl::getStrides($_op);
264-
}]>,
265-
InterfaceMethod<[{
266-
Returns dilations.
267-
}],
268-
"::llvm::SmallVector<int64_t, 2>", "getDilationsVector", (ins),
269-
/*methodBody=*/[{}],
270-
/*defaultImplementation=*/[{
253+
}
254+
// Returns dilations.
255+
::llvm::SmallVector<int64_t, 2> getDilationsVector() {
271256
return detail::convolution_impl::getDilations($_op);
272-
}]>
273-
];
257+
}
258+
}];
274259
}
275260

276261
def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
392392
[AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
393393

394394
let summary = [{
395-
Performs N-D grouped convolution with switchable channel position; either first or last.
395+
Performs N-D grouped convolution with parametrizable operand layouts.
396396
}];
397397
let description = [{
398398
Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`,
@@ -490,14 +490,27 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
490490
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
491491

492492
// Implement functions necessary for LinalgOp.
493-
ArrayAttr getIndexingMaps();
493+
::mlir::ArrayAttr getIndexingMaps() {
494+
::mlir::ArrayAttr cached = (*this)->getAttrOfType<::mlir::ArrayAttr>(
495+
LinalgDialect::kMemoizedIndexingMapsAttrName);
496+
if (cached)
497+
return cached;
498+
499+
cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
500+
getContext(), getSpatialRank(), getOperandConvDims(), getStridesVector(),
501+
getDilationsVector());
502+
503+
(*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
504+
return cached;
505+
}
506+
494507

495508
// Implement functions necessary for GroupedConvolutionOpInterface
496509
int64_t getSpatialRank() {
497510
return detail::grouped_convolution_impl::getSpatialRank(*this);
498511
}
499512

500-
SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
513+
SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getOperandConvDims() {
501514
SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
502515
for (auto attr : (*this).getLayoutsAttr().getValue()) {
503516
std::string layoutStr = cast<StringAttr>(attr).getValue().str();
@@ -513,15 +526,24 @@ def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
513526
}
514527

515528
int64_t getOutputChannelPosition() {
516-
return 2;
529+
std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[2]).getValue().str();
530+
size_t pos = layoutStr.find("f");
531+
assert(pos != ::std::string::npos);
532+
return pos;
517533
}
518534

519535
int64_t getInputChannelPosition() {
520-
return 2;
536+
std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
537+
size_t pos = layoutStr.find("c");
538+
assert(pos != ::std::string::npos);
539+
return pos;
521540
}
522541

523542
int64_t getInputGroupsPosition() {
524-
return 1;
543+
std::string layoutStr = cast<StringAttr>((*this).getLayoutsAttr().getValue()[0]).getValue().str();
544+
size_t pos = layoutStr.find("g");
545+
assert(pos != ::std::string::npos);
546+
return pos;
525547
}
526548
}];
527549
}

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -880,12 +880,12 @@ mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
880880
return failure();
881881
if (GroupedConvolutionOpInterface conv =
882882
dyn_cast<GroupedConvolutionOpInterface>(op)) {
883-
const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
883+
const auto imageType = cast<ShapedType>(conv.image().getType());
884884
const auto imageRank = imageType.getRank();
885885
const auto kernelRank =
886-
conv.filter().getType().cast<ShapedType>().getRank();
886+
cast<ShapedType>(conv.filter().getType()).getRank();
887887
const auto initType =
888-
cast<LinalgOp>(op).getDpsInits()[0].getType().dyn_cast<ShapedType>();
888+
cast<ShapedType>(cast<LinalgOp>(op).getDpsInits()[0].getType());
889889
const auto initRank = initType.getRank();
890890
if (imageRank != kernelRank || imageRank != initRank)
891891
return op->emitError(

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,20 +1825,6 @@ void GroupedConvNDOp::getEffects(
18251825
return detail::convolution_impl::getEffects(*this, effects);
18261826
}
18271827

1828-
ArrayAttr GroupedConvNDOp::getIndexingMaps() {
1829-
ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
1830-
LinalgDialect::kMemoizedIndexingMapsAttrName);
1831-
if (cached)
1832-
return cached;
1833-
1834-
cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
1835-
getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
1836-
getDilationsVector());
1837-
1838-
(*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
1839-
return cached;
1840-
}
1841-
18421828
//===----------------------------------------------------------------------===//
18431829
// TransposeOp
18441830
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)