Skip to content

Commit c96a85a

Browse files
authored
[mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices (#114137)
1 parent 7a5b040 commit c96a85a

File tree

2 files changed

+82
-22
lines changed

2 files changed

+82
-22
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2122
#include "mlir/IR/Attributes.h"
2223
#include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
4041
/// Returns the integer value from the first valid input element, assuming Value
4142
/// inputs are defined by a constant index ops and Attribute inputs are integer
4243
/// attributes.
43-
static uint64_t getFirstIntValue(ValueRange values) {
44-
return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
45-
}
46-
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
47-
return cast<IntegerAttr>(attr[0]).getInt();
48-
}
4944
static uint64_t getFirstIntValue(ArrayAttr attr) {
5045
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
5146
}
52-
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
53-
auto attr = foldResults[0].dyn_cast<Attribute>();
54-
if (attr)
55-
return getFirstIntValue(attr);
56-
57-
return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
58-
}
5947

6048
/// Returns the number of bits for the given scalar/vector type.
6149
static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157145
LogicalResult
158146
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
159147
ConversionPatternRewriter &rewriter) const override {
160-
if (extractOp.hasDynamicPosition())
161-
return failure();
162-
163148
Type dstType = getTypeConverter()->convertType(extractOp.getType());
164149
if (!dstType)
165150
return failure();
@@ -169,9 +154,15 @@ struct VectorExtractOpConvert final
169154
return success();
170155
}
171156

172-
int32_t id = getFirstIntValue(extractOp.getMixedPosition());
173-
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
174-
extractOp, adaptor.getVector(), id);
157+
if (std::optional<int64_t> id =
158+
getConstantIntValue(extractOp.getMixedPosition()[0]))
159+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
160+
extractOp, dstType, adaptor.getVector(),
161+
rewriter.getI32ArrayAttr(id.value()));
162+
else
163+
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
164+
extractOp, dstType, adaptor.getVector(),
165+
adaptor.getDynamicPosition()[0]);
175166
return success();
176167
}
177168
};
@@ -249,9 +240,14 @@ struct VectorInsertOpConvert final
249240
return success();
250241
}
251242

252-
int32_t id = getFirstIntValue(insertOp.getMixedPosition());
253-
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
254-
insertOp, adaptor.getSource(), adaptor.getDest(), id);
243+
if (std::optional<int64_t> id =
244+
getConstantIntValue(insertOp.getMixedPosition()[0]))
245+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
246+
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
247+
else
248+
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
249+
insertOp, insertOp.getDest(), adaptor.getSource(),
250+
adaptor.getDynamicPosition()[0]);
255251
return success();
256252
}
257253
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,37 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
186186

187187
// -----
188188

189+
// CHECK-LABEL: @extract_size1_vector_dynamic
190+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
191+
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
192+
// CHECK: return %[[R]]
193+
func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 {
194+
%0 = vector.extract %arg0[%id] : f32 from vector<1xf32>
195+
return %0: f32
196+
}
197+
198+
// -----
199+
200+
// CHECK-LABEL: @extract_dynamic
201+
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
202+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
203+
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
204+
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
205+
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
206+
return %0: f32
207+
}
208+
209+
// CHECK-LABEL: @extract_dynamic_cst
210+
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
211+
// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
212+
func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
213+
%idx = arith.constant 1 : index
214+
%0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
215+
return %0: f32
216+
}
217+
218+
// -----
219+
189220
// CHECK-LABEL: @insert
190221
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
191222
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +247,39 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
216247

217248
// -----
218249

250+
// CHECK-LABEL: @insert_size1_vector_dynamic
251+
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
252+
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
253+
// CHECK: return %[[R]]
254+
func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> {
255+
%1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32>
256+
return %1 : vector<1xf32>
257+
}
258+
259+
// -----
260+
261+
// CHECK-LABEL: @insert_dynamic
262+
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
263+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
264+
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
265+
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
266+
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
267+
return %0: vector<4xf32>
268+
}
269+
270+
// -----
271+
272+
// CHECK-LABEL: @insert_dynamic_cst
273+
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
274+
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
275+
func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
276+
%idx = arith.constant 2 : index
277+
%0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
278+
return %0: vector<4xf32>
279+
}
280+
281+
// -----
282+
219283
// CHECK-LABEL: @extract_element
220284
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
221285
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32

0 commit comments

Comments
 (0)