Skip to content

Commit 5df62bd

Browse files
authored
[mlir][spirv] Support poison index when converting vector.insert/extract (#125560)
This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves #124162.
1 parent 814db6c commit 5df62bd

File tree

2 files changed

+107
-17
lines changed

2 files changed

+107
-17
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
137137
}
138138
};
139139

140+
// SPIR-V does not have a concept of a poison index for certain instructions,
141+
// which creates a UB hazard when lowering from otherwise equivalent Vector
142+
// dialect instructions, because this index will be considered out-of-bounds.
143+
// To avoid this, this function implements a dynamic sanitization that returns
144+
// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145+
// (presumably more efficient), and otherwise index 0 (always in-bounds).
146+
static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147+
Location loc, Value dynamicIndex,
148+
int64_t kPoisonIndex, unsigned vectorSize) {
149+
if (llvm::isPowerOf2_32(vectorSize)) {
150+
Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
151+
loc, dynamicIndex.getType(),
152+
rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
153+
return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
154+
inBoundsMask);
155+
}
156+
Value poisonIndex = rewriter.create<spirv::ConstantOp>(
157+
loc, dynamicIndex.getType(),
158+
rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
159+
Value cmpResult =
160+
rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161+
return rewriter.create<spirv::SelectOp>(
162+
loc, cmpResult,
163+
spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
164+
dynamicIndex);
165+
}
166+
140167
struct VectorExtractOpConvert final
141168
: public OpConversionPattern<vector::ExtractOp> {
142169
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
154181
}
155182

156183
if (std::optional<int64_t> id =
157-
getConstantIntValue(extractOp.getMixedPosition()[0]))
158-
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
159-
extractOp, dstType, adaptor.getVector(),
160-
rewriter.getI32ArrayAttr(id.value()));
161-
else
184+
getConstantIntValue(extractOp.getMixedPosition()[0])) {
185+
// TODO: ExtractOp::fold() already can fold a static poison index to
186+
// ub.poison; remove this once ub.poison can be converted to SPIR-V.
187+
if (id == vector::ExtractOp::kPoisonIndex) {
188+
// Arbitrary choice of poison result, intended to stick out.
189+
Value zero =
190+
spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
191+
rewriter.replaceOp(extractOp, zero);
192+
} else
193+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
194+
extractOp, dstType, adaptor.getVector(),
195+
rewriter.getI32ArrayAttr(id.value()));
196+
} else {
197+
Value sanitizedIndex = sanitizeDynamicIndex(
198+
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
199+
vector::ExtractOp::kPoisonIndex,
200+
extractOp.getSourceVectorType().getNumElements());
162201
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
163-
extractOp, dstType, adaptor.getVector(),
164-
adaptor.getDynamicPosition()[0]);
202+
extractOp, dstType, adaptor.getVector(), sanitizedIndex);
203+
}
165204
return success();
166205
}
167206
};
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
266305
}
267306

268307
if (std::optional<int64_t> id =
269-
getConstantIntValue(insertOp.getMixedPosition()[0]))
270-
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
271-
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
272-
else
308+
getConstantIntValue(insertOp.getMixedPosition()[0])) {
309+
// TODO: ExtractOp::fold() already can fold a static poison index to
310+
// ub.poison; remove this once ub.poison can be converted to SPIR-V.
311+
if (id == vector::InsertOp::kPoisonIndex) {
312+
// Arbitrary choice of poison result, intended to stick out.
313+
Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
314+
insertOp.getLoc(), rewriter);
315+
rewriter.replaceOp(insertOp, zero);
316+
} else
317+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
318+
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
319+
} else {
320+
Value sanitizedIndex = sanitizeDynamicIndex(
321+
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
322+
vector::InsertOp::kPoisonIndex,
323+
insertOp.getDestVectorType().getNumElements());
273324
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
274-
insertOp, insertOp.getDest(), adaptor.getSource(),
275-
adaptor.getDynamicPosition()[0]);
325+
insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
326+
}
276327
return success();
277328
}
278329
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
176176
// -----
177177

178178
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
179-
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
179+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
180+
// CHECK: return %[[ZERO]]
180181
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
181182
return %0: f32
182183
}
@@ -208,12 +209,31 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
208209
// CHECK-LABEL: @extract_dynamic
209210
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
210211
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
211-
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
212+
// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
213+
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
214+
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[MASKED]]] : vector<4xf32>, i32
212215
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
213216
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
214217
return %0: f32
215218
}
216219

220+
// -----
221+
222+
// CHECK-LABEL: @extract_dynamic_non_pow2
223+
// CHECK-SAME: %[[V:.*]]: vector<3xf32>, %[[ARG1:.*]]: index
224+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
225+
// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
226+
// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
227+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
228+
// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
229+
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<3xf32>, i32
230+
func.func @extract_dynamic_non_pow2(%arg0 : vector<3xf32>, %id : index) -> f32 {
231+
%0 = vector.extract %arg0[%id] : f32 from vector<3xf32>
232+
return %0: f32
233+
}
234+
235+
// -----
236+
217237
// CHECK-LABEL: @extract_dynamic_cst
218238
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
219239
// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -264,8 +284,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
264284

265285
// -----
266286

287+
// CHECK-LABEL: @insert_poison_idx
288+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
289+
// CHECK: return %[[ZERO]]
267290
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
268-
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
269291
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
270292
return %1: vector<4xf32>
271293
}
@@ -306,14 +328,31 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
306328
// CHECK-LABEL: @insert_dynamic
307329
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
308330
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
309-
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
331+
// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
332+
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
333+
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[MASKED]]] : vector<4xf32>, i32
310334
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
311335
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
312336
return %0: vector<4xf32>
313337
}
314338

315339
// -----
316340

341+
// CHECK-LABEL: @insert_dynamic_non_pow2
342+
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<3xf32>, %[[ARG2:.*]]: index
343+
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
344+
// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
345+
// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
346+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
347+
// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
348+
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<3xf32>, i32
349+
func.func @insert_dynamic_non_pow2(%val: f32, %arg0 : vector<3xf32>, %id : index) -> vector<3xf32> {
350+
%0 = vector.insert %val, %arg0[%id] : f32 into vector<3xf32>
351+
return %0: vector<3xf32>
352+
}
353+
354+
// -----
355+
317356
// CHECK-LABEL: @insert_dynamic_cst
318357
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
319358
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>

0 commit comments

Comments
 (0)