@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
137
137
}
138
138
};
139
139
140
+ // SPIR-V does not have a concept of a poison index for certain instructions,
141
+ // which creates a UB hazard when lowering from otherwise equivalent Vector
142
+ // dialect instructions, because this index will be considered out-of-bounds.
143
+ // To avoid this, this function implements a dynamic sanitization, arbitrarily
144
+ // choosing to replace the poison index with index 0 (always in-bounds).
145
+ static Value sanitizeDynamicIndex (ConversionPatternRewriter &rewriter,
146
+ Location loc, Value dynamicIndex,
147
+ int64_t kPoisonIndex ) {
148
+ Value poisonIndex = rewriter.create <spirv::ConstantOp>(
149
+ loc, dynamicIndex.getType (),
150
+ rewriter.getIntegerAttr (dynamicIndex.getType (), kPoisonIndex ));
151
+ Value cmpResult =
152
+ rewriter.create <spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
153
+ Value sanitizedIndex = rewriter.create <spirv::SelectOp>(
154
+ loc, cmpResult,
155
+ spirv::ConstantOp::getZero (dynamicIndex.getType (), loc, rewriter),
156
+ dynamicIndex);
157
+ return sanitizedIndex;
158
+ }
159
+
140
160
struct VectorExtractOpConvert final
141
161
: public OpConversionPattern<vector::ExtractOp> {
142
162
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
154
174
}
155
175
156
176
if (std::optional<int64_t > id =
157
- getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
158
- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
159
- extractOp, dstType, adaptor.getVector (),
160
- rewriter.getI32ArrayAttr (id.value ()));
161
- else
177
+ getConstantIntValue (extractOp.getMixedPosition ()[0 ])) {
178
+ // TODO: It would be better to apply the ub.poison folding for this case
179
+ // unconditionally, and have a specific SPIR-V lowering for it,
180
+ // rather than having to handle it here.
181
+ if (id == vector::ExtractOp::kPoisonIndex ) {
182
+ // Arbitrary choice of poison result, intended to stick out.
183
+ Value zero =
184
+ spirv::ConstantOp::getZero (dstType, extractOp.getLoc (), rewriter);
185
+ rewriter.replaceOp (extractOp, zero);
186
+ } else
187
+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
188
+ extractOp, dstType, adaptor.getVector (),
189
+ rewriter.getI32ArrayAttr (id.value ()));
190
+ } else {
191
+ Value sanitizedIndex = sanitizeDynamicIndex (
192
+ rewriter, extractOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
193
+ vector::ExtractOp::kPoisonIndex );
162
194
rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
163
- extractOp, dstType, adaptor.getVector (),
164
- adaptor. getDynamicPosition ()[ 0 ]);
195
+ extractOp, dstType, adaptor.getVector (), sanitizedIndex);
196
+ }
165
197
return success ();
166
198
}
167
199
};
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
266
298
}
267
299
268
300
if (std::optional<int64_t > id =
269
- getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
270
- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
271
- insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
272
- else
301
+ getConstantIntValue (insertOp.getMixedPosition ()[0 ])) {
302
+ // TODO: It would be better to apply the ub.poison folding for this case
303
+ // unconditionally, and have a specific SPIR-V lowering for it,
304
+ // rather than having to handle it here.
305
+ if (id == vector::InsertOp::kPoisonIndex ) {
306
+ // Arbitrary choice of poison result, intended to stick out.
307
+ Value zero = spirv::ConstantOp::getZero (insertOp.getDestVectorType (),
308
+ insertOp.getLoc (), rewriter);
309
+ rewriter.replaceOp (insertOp, zero);
310
+ } else
311
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
312
+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
313
+ } else {
314
+ Value sanitizedIndex = sanitizeDynamicIndex (
315
+ rewriter, insertOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
316
+ vector::InsertOp::kPoisonIndex );
273
317
rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
274
- insertOp, insertOp.getDest (), adaptor.getSource (),
275
- adaptor. getDynamicPosition ()[ 0 ]);
318
+ insertOp, insertOp.getDest (), adaptor.getSource (), sanitizedIndex);
319
+ }
276
320
return success ();
277
321
}
278
322
};
0 commit comments