Skip to content

Commit 9b14fc1

Browse files
committed
[mlir][spirv] Support poison index when converting vector.insert/extract
This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves #124162.
1 parent 25ae1a2 commit 9b14fc1

File tree

2 files changed

+72
-17
lines changed

2 files changed

+72
-17
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
137137
}
138138
};
139139

140+
// SPIR-V does not have a concept of a poison index for certain instructions,
141+
// which creates a UB hazard when lowering from otherwise equivalent Vector
142+
// dialect instructions, because this index will be considered out-of-bounds.
143+
// To avoid this, this function implements a dynamic sanitization, arbitrarily
144+
// choosing to replace the poison index with index 0 (always in-bounds).
145+
static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
146+
Location loc, Value dynamicIndex,
147+
int64_t kPoisonIndex) {
148+
Value poisonIndex = rewriter.create<spirv::ConstantOp>(
149+
loc, dynamicIndex.getType(),
150+
rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
151+
Value cmpResult =
152+
rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
153+
Value sanitizedIndex = rewriter.create<spirv::SelectOp>(
154+
loc, cmpResult,
155+
spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
156+
dynamicIndex);
157+
return sanitizedIndex;
158+
}
159+
140160
struct VectorExtractOpConvert final
141161
: public OpConversionPattern<vector::ExtractOp> {
142162
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
154174
}
155175

156176
if (std::optional<int64_t> id =
157-
getConstantIntValue(extractOp.getMixedPosition()[0]))
158-
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
159-
extractOp, dstType, adaptor.getVector(),
160-
rewriter.getI32ArrayAttr(id.value()));
161-
else
177+
getConstantIntValue(extractOp.getMixedPosition()[0])) {
178+
// TODO: It would be better to apply the ub.poison folding for this case
179+
// unconditionally, and have a specific SPIR-V lowering for it,
180+
// rather than having to handle it here.
181+
if (id == vector::ExtractOp::kPoisonIndex) {
182+
// Arbitrary choice of poison result, intended to stick out.
183+
Value zero =
184+
spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
185+
rewriter.replaceOp(extractOp, zero);
186+
} else
187+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
188+
extractOp, dstType, adaptor.getVector(),
189+
rewriter.getI32ArrayAttr(id.value()));
190+
} else {
191+
Value sanitizedIndex = sanitizeDynamicIndex(
192+
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
193+
vector::ExtractOp::kPoisonIndex);
162194
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
163-
extractOp, dstType, adaptor.getVector(),
164-
adaptor.getDynamicPosition()[0]);
195+
extractOp, dstType, adaptor.getVector(), sanitizedIndex);
196+
}
165197
return success();
166198
}
167199
};
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
266298
}
267299

268300
if (std::optional<int64_t> id =
269-
getConstantIntValue(insertOp.getMixedPosition()[0]))
270-
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
271-
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
272-
else
301+
getConstantIntValue(insertOp.getMixedPosition()[0])) {
302+
// TODO: It would be better to apply the ub.poison folding for this case
303+
// unconditionally, and have a specific SPIR-V lowering for it,
304+
// rather than having to handle it here.
305+
if (id == vector::InsertOp::kPoisonIndex) {
306+
// Arbitrary choice of poison result, intended to stick out.
307+
Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
308+
insertOp.getLoc(), rewriter);
309+
rewriter.replaceOp(insertOp, zero);
310+
} else
311+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
312+
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
313+
} else {
314+
Value sanitizedIndex = sanitizeDynamicIndex(
315+
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
316+
vector::InsertOp::kPoisonIndex);
273317
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
274-
insertOp, insertOp.getDest(), adaptor.getSource(),
275-
adaptor.getDynamicPosition()[0]);
318+
insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
319+
}
276320
return success();
277321
}
278322
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
176176
// -----
177177

178178
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
179-
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
179+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
180+
// CHECK: return %[[ZERO]]
180181
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
181182
return %0: f32
182183
}
@@ -208,7 +209,11 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
208209
// CHECK-LABEL: @extract_dynamic
209210
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
210211
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
211-
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
212+
// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
213+
// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
214+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
215+
// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
216+
// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<4xf32>, i32
212217
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
213218
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
214219
return %0: f32
@@ -264,8 +269,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
264269

265270
// -----
266271

272+
// CHECK-LABEL: @insert_poison_idx
273+
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
274+
// CHECK: return %[[ZERO]]
267275
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
268-
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
269276
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
270277
return %1: vector<4xf32>
271278
}
@@ -306,7 +313,11 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
306313
// CHECK-LABEL: @insert_dynamic
307314
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
308315
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
309-
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
316+
// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
317+
// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
318+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
319+
// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
320+
// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<4xf32>, i32
310321
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
311322
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
312323
return %0: vector<4xf32>

0 commit comments

Comments
 (0)