17
17
#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18
18
#include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19
19
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
20
21
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
22
#include " mlir/IR/Attributes.h"
22
23
#include " mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
40
41
// / Returns the integer value from the first valid input element, assuming Value
41
42
// / inputs are defined by a constant index ops and Attribute inputs are integer
42
43
// / attributes.
43
- static uint64_t getFirstIntValue (ValueRange values) {
44
- return values[0 ].getDefiningOp <arith::ConstantIndexOp>().value ();
45
- }
46
- static uint64_t getFirstIntValue (ArrayRef<Attribute> attr) {
47
- return cast<IntegerAttr>(attr[0 ]).getInt ();
48
- }
49
44
static uint64_t getFirstIntValue (ArrayAttr attr) {
50
45
return (*attr.getAsValueRange <IntegerAttr>().begin ()).getZExtValue ();
51
46
}
52
- static uint64_t getFirstIntValue (ArrayRef<OpFoldResult> foldResults) {
53
- auto attr = foldResults[0 ].dyn_cast <Attribute>();
54
- if (attr)
55
- return getFirstIntValue (attr);
56
-
57
- return getFirstIntValue (ValueRange{foldResults[0 ].get <Value>()});
58
- }
59
47
60
48
// / Returns the number of bits for the given scalar/vector type.
61
49
static int getNumBits (Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157
145
LogicalResult
158
146
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
159
147
ConversionPatternRewriter &rewriter) const override {
160
- if (extractOp.hasDynamicPosition ())
161
- return failure ();
162
-
163
148
Type dstType = getTypeConverter ()->convertType (extractOp.getType ());
164
149
if (!dstType)
165
150
return failure ();
@@ -169,9 +154,15 @@ struct VectorExtractOpConvert final
169
154
return success ();
170
155
}
171
156
172
- int32_t id = getFirstIntValue (extractOp.getMixedPosition ());
173
- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
174
- extractOp, adaptor.getVector (), id);
157
+ if (std::optional<int64_t > id =
158
+ getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
159
+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
160
+ extractOp, dstType, adaptor.getVector (),
161
+ rewriter.getI32ArrayAttr (id.value ()));
162
+ else
163
+ rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
164
+ extractOp, dstType, adaptor.getVector (),
165
+ adaptor.getDynamicPosition ()[0 ]);
175
166
return success ();
176
167
}
177
168
};
@@ -249,9 +240,14 @@ struct VectorInsertOpConvert final
249
240
return success ();
250
241
}
251
242
252
- int32_t id = getFirstIntValue (insertOp.getMixedPosition ());
253
- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254
- insertOp, adaptor.getSource (), adaptor.getDest (), id);
243
+ if (std::optional<int64_t > id =
244
+ getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
245
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
246
+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
247
+ else
248
+ rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
249
+ insertOp, insertOp.getDest (), adaptor.getSource (),
250
+ adaptor.getDynamicPosition ()[0 ]);
255
251
return success ();
256
252
}
257
253
};
0 commit comments