@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
137
137
}
138
138
};
139
139
140
+ // SPIR-V does not have a concept of a poison index for certain instructions,
141
+ // which creates a UB hazard when lowering from otherwise equivalent Vector
142
+ // dialect instructions, because this index will be considered out-of-bounds.
143
+ // To avoid this, this function implements a dynamic sanitization that returns
144
+ // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145
+ // (presumably more efficient), and otherwise index 0 (always in-bounds).
146
+ static Value sanitizeDynamicIndex (ConversionPatternRewriter &rewriter,
147
+ Location loc, Value dynamicIndex,
148
+ int64_t kPoisonIndex , unsigned vectorSize) {
149
+ if (llvm::isPowerOf2_32 (vectorSize)) {
150
+ Value inBoundsMask = rewriter.create <spirv::ConstantOp>(
151
+ loc, dynamicIndex.getType (),
152
+ rewriter.getIntegerAttr (dynamicIndex.getType (), vectorSize - 1 ));
153
+ return rewriter.create <spirv::BitwiseAndOp>(loc, dynamicIndex,
154
+ inBoundsMask);
155
+ }
156
+ Value poisonIndex = rewriter.create <spirv::ConstantOp>(
157
+ loc, dynamicIndex.getType (),
158
+ rewriter.getIntegerAttr (dynamicIndex.getType (), kPoisonIndex ));
159
+ Value cmpResult =
160
+ rewriter.create <spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161
+ return rewriter.create <spirv::SelectOp>(
162
+ loc, cmpResult,
163
+ spirv::ConstantOp::getZero (dynamicIndex.getType (), loc, rewriter),
164
+ dynamicIndex);
165
+ }
166
+
140
167
struct VectorExtractOpConvert final
141
168
: public OpConversionPattern<vector::ExtractOp> {
142
169
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
154
181
}
155
182
156
183
if (std::optional<int64_t > id =
157
- getConstantIntValue (extractOp.getMixedPosition ()[0 ]))
158
- rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
159
- extractOp, dstType, adaptor.getVector (),
160
- rewriter.getI32ArrayAttr (id.value ()));
161
- else
184
+ getConstantIntValue (extractOp.getMixedPosition ()[0 ])) {
185
+ // TODO: ExtractOp::fold() already can fold a static poison index to
186
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
187
+ if (id == vector::ExtractOp::kPoisonIndex ) {
188
+ // Arbitrary choice of poison result, intended to stick out.
189
+ Value zero =
190
+ spirv::ConstantOp::getZero (dstType, extractOp.getLoc (), rewriter);
191
+ rewriter.replaceOp (extractOp, zero);
192
+ } else
193
+ rewriter.replaceOpWithNewOp <spirv::CompositeExtractOp>(
194
+ extractOp, dstType, adaptor.getVector (),
195
+ rewriter.getI32ArrayAttr (id.value ()));
196
+ } else {
197
+ Value sanitizedIndex = sanitizeDynamicIndex (
198
+ rewriter, extractOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
199
+ vector::ExtractOp::kPoisonIndex ,
200
+ extractOp.getSourceVectorType ().getNumElements ());
162
201
rewriter.replaceOpWithNewOp <spirv::VectorExtractDynamicOp>(
163
- extractOp, dstType, adaptor.getVector (),
164
- adaptor. getDynamicPosition ()[ 0 ]);
202
+ extractOp, dstType, adaptor.getVector (), sanitizedIndex);
203
+ }
165
204
return success ();
166
205
}
167
206
};
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
266
305
}
267
306
268
307
if (std::optional<int64_t > id =
269
- getConstantIntValue (insertOp.getMixedPosition ()[0 ]))
270
- rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
271
- insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
272
- else
308
+ getConstantIntValue (insertOp.getMixedPosition ()[0 ])) {
309
+ // TODO: ExtractOp::fold() already can fold a static poison index to
310
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
311
+ if (id == vector::InsertOp::kPoisonIndex ) {
312
+ // Arbitrary choice of poison result, intended to stick out.
313
+ Value zero = spirv::ConstantOp::getZero (insertOp.getDestVectorType (),
314
+ insertOp.getLoc (), rewriter);
315
+ rewriter.replaceOp (insertOp, zero);
316
+ } else
317
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(
318
+ insertOp, adaptor.getSource (), adaptor.getDest (), id.value ());
319
+ } else {
320
+ Value sanitizedIndex = sanitizeDynamicIndex (
321
+ rewriter, insertOp.getLoc (), adaptor.getDynamicPosition ()[0 ],
322
+ vector::InsertOp::kPoisonIndex ,
323
+ insertOp.getDestVectorType ().getNumElements ());
273
324
rewriter.replaceOpWithNewOp <spirv::VectorInsertDynamicOp>(
274
- insertOp, insertOp.getDest (), adaptor.getSource (),
275
- adaptor. getDynamicPosition ()[ 0 ]);
325
+ insertOp, insertOp.getDest (), adaptor.getSource (), sanitizedIndex);
326
+ }
276
327
return success ();
277
328
}
278
329
};
0 commit comments