Skip to content

Commit f15d21e

Browse files
committed
[mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices
1 parent 13b5899 commit f15d21e

File tree

2 files changed

+68
-22
lines changed

2 files changed

+68
-22
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2122
#include "mlir/IR/Attributes.h"
2223
#include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
4041
/// Returns the integer value from the first valid input element, assuming Value
4142
/// inputs are defined by a constant index ops and Attribute inputs are integer
4243
/// attributes.
43-
static uint64_t getFirstIntValue(ValueRange values) {
44-
return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
45-
}
46-
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
47-
return cast<IntegerAttr>(attr[0]).getInt();
48-
}
4944
static uint64_t getFirstIntValue(ArrayAttr attr) {
5045
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
5146
}
52-
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
53-
auto attr = foldResults[0].dyn_cast<Attribute>();
54-
if (attr)
55-
return getFirstIntValue(attr);
56-
57-
return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
58-
}
5947

6048
/// Returns the number of bits for the given scalar/vector type.
6149
static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
157145
LogicalResult
158146
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
159147
ConversionPatternRewriter &rewriter) const override {
160-
if (extractOp.hasDynamicPosition())
161-
return failure();
162-
163148
Type dstType = getTypeConverter()->convertType(extractOp.getType());
164149
if (!dstType)
165150
return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
169154
return success();
170155
}
171156

172-
int32_t id = getFirstIntValue(extractOp.getMixedPosition());
173-
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
174-
extractOp, adaptor.getVector(), id);
157+
std::optional<int64_t> id =
158+
getConstantIntValue(extractOp.getMixedPosition()[0]);
159+
160+
if (id.has_value())
161+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
162+
extractOp, dstType, adaptor.getVector(),
163+
rewriter.getI32ArrayAttr(id.value()));
164+
else
165+
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
166+
extractOp, dstType, adaptor.getVector(),
167+
adaptor.getDynamicPosition()[0]);
175168
return success();
176169
}
177170
};
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
249242
return success();
250243
}
251244

252-
int32_t id = getFirstIntValue(insertOp.getMixedPosition());
253-
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
254-
insertOp, adaptor.getSource(), adaptor.getDest(), id);
245+
std::optional<int64_t> id =
246+
getConstantIntValue(insertOp.getMixedPosition()[0]);
247+
248+
// rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
249+
// insertOp, adaptor.getSource(), adaptor.getDest(), id);
250+
// return success();
251+
252+
if (id.has_value())
253+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
254+
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
255+
else
256+
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
257+
insertOp, insertOp.getDest(), adaptor.getSource(),
258+
adaptor.getDynamicPosition()[0]);
255259
return success();
256260
}
257261
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
186186

187187
// -----
188188

189+
// CHECK-LABEL: @extract_dynamic
190+
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
191+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
192+
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
193+
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
194+
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
195+
return %0: f32
196+
}
197+
198+
// CHECK-LABEL: @extract_dynamic_cst
199+
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
200+
// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
201+
func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
202+
%idx = arith.constant 1 : index
203+
%0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
204+
return %0: f32
205+
}
206+
207+
// -----
208+
189209
// CHECK-LABEL: @insert
190210
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
191211
// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
216236

217237
// -----
218238

239+
// CHECK-LABEL: @insert_dynamic
240+
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
241+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
242+
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
243+
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
244+
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
245+
return %0: vector<4xf32>
246+
}
247+
248+
// -----
249+
250+
// CHECK-LABEL: @insert_dynamic_cst
251+
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
252+
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
253+
func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
254+
%idx = arith.constant 2 : index
255+
%0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
256+
return %0: vector<4xf32>
257+
}
258+
259+
// -----
260+
219261
// CHECK-LABEL: @extract_element
220262
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
221263
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32

0 commit comments

Comments
 (0)