Skip to content

Commit a5cee3e

Browse files
committed
[mlir][linalg] Add a padding case for ComplexType
If the paddingAttr is an ArrayAttr with two values we know that the element type is a `ComplexType` and we should pad the value accordingly. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D154908
1 parent 4c42ab1 commit a5cee3e

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1010

1111
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/Complex/IR/Complex.h"
1213
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1415
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
@@ -125,8 +126,17 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
125126
return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
126127
}
127128
Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
128-
Value paddingValue = rewriter.create<arith::ConstantOp>(
129-
opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
129+
130+
Value paddingValue;
131+
if (auto complexTy = dyn_cast<ComplexType>(
132+
getElementTypeOrSelf(opOperand->get().getType()))) {
133+
auto complexAttr = cast<ArrayAttr>(paddingAttr);
134+
paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
135+
complexTy, complexAttr);
136+
} else {
137+
paddingValue = rewriter.create<arith::ConstantOp>(
138+
opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
139+
}
130140

131141
// Pad the operand to the bounding box defined by `paddedShape`.
132142
auto paddedTensorType = RankedTensorType::get(

0 commit comments

Comments
 (0)