17
17
#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18
18
#include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19
19
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
20
21
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
22
#include " mlir/IR/Attributes.h"
22
23
#include " mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
40
41
// / Returns the integer value from the first valid input element, assuming Value
41
42
// / inputs are defined by a constant index ops and Attribute inputs are integer
42
43
// / attributes.
43
- static uint64_t getFirstIntValue (ValueRange values) {
44
- return values[0 ].getDefiningOp <arith::ConstantIndexOp>().value ();
45
- }
46
- static uint64_t getFirstIntValue (ArrayRef<Attribute> attr) {
47
- return cast<IntegerAttr>(attr[0 ]).getInt ();
48
- }
49
44
static uint64_t getFirstIntValue (ArrayAttr attr) {
50
45
return (*attr.getAsValueRange <IntegerAttr>().begin ()).getZExtValue ();
51
46
}
52
- static uint64_t getFirstIntValue (ArrayRef<OpFoldResult> foldResults) {
53
- auto attr = foldResults[0 ].dyn_cast <Attribute>();
54
- if (attr)
55
- return getFirstIntValue (attr);
56
-
57
- return getFirstIntValue (ValueRange{foldResults[0 ].get <Value>()});
58
- }
59
47
60
48
// / Returns the number of bits for the given scalar/vector type.
61
49
static int getNumBits (Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157
145
LogicalResult
158
146
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
159
147
ConversionPatternRewriter &rewriter) const override {
160
- if (extractOp.hasDynamicPosition ())
161
- return failure ();
162
-
163
148
Type dstType = getTypeConverter ()->convertType (extractOp.getType ());
164
149
if (!dstType)
165
150
return failure ();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
169
154
return success ();
170
155
}
171
156
172
- int32_t id = getFirstIntValue (extractOp.getMixedPosition ());
173
- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
174
- extractOp, adaptor.getVector (), id);
157
+ std::optional<int64_t > id =
158
+ getConstantIntValue (extractOp.getMixedPosition ()[0 ]);
159
+
160
+ if (id.has_value ())
161
+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
162
+ extractOp, dstType, adaptor.getVector (),
163
+ rewriter.getI32ArrayAttr (id.value ()));
164
+ else
165
+ rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
166
+ extractOp, dstType, adaptor.getVector (),
167
+ adaptor.getDynamicPosition ()[0 ]);
175
168
return success ();
176
169
}
177
170
};
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
249
242
return success ();
250
243
}
251
244
252
- int32_t id = getFirstIntValue (insertOp.getMixedPosition ());
253
- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254
- insertOp, adaptor.getSource (), adaptor.getDest (), id);
245
+ std::optional<int64_t > id =
246
+ getConstantIntValue (insertOp.getMixedPosition ()[0 ]);
247
+
248
+ // rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
249
+ // insertOp, adaptor.getSource(), adaptor.getDest(), id);
250
+ // return success();
251
+
252
+ if (id.has_value ())
253
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
254
+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
255
+ else
256
+ rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
257
+ insertOp, insertOp.getDest (), adaptor.getSource (),
258
+ adaptor.getDynamicPosition ()[0 ]);
255
259
return success ();
256
260
}
257
261
};
0 commit comments