@@ -39,17 +39,17 @@ namespace {
39
39
// ===----------------------------------------------------------------------===//
40
40
// Helpers
41
41
// ===----------------------------------------------------------------------===//
42
- Attribute getScalarOrSplatAttr (Type type, int64_t value) {
42
+ static Attribute getScalarOrSplatAttr (Type type, int64_t value) {
43
43
APInt sizedValue (getElementTypeOrSelf (type).getIntOrFloatBitWidth (), value);
44
44
if (auto intTy = dyn_cast<IntegerType>(type))
45
45
return IntegerAttr::get (intTy, sizedValue);
46
46
47
47
return SplatElementsAttr::get (cast<ShapedType>(type), sizedValue);
48
48
}
49
49
50
- Value lowerExtendedMultiplication (Operation *mulOp, PatternRewriter &rewriter ,
51
- Value lhs , Value rhs ,
52
- bool signExtendArguments) {
50
+ static Value lowerExtendedMultiplication (Operation *mulOp,
51
+ PatternRewriter &rewriter , Value lhs ,
52
+ Value rhs, bool signExtendArguments) {
53
53
Location loc = mulOp->getLoc ();
54
54
Type argTy = lhs.getType ();
55
55
// Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
203
203
}
204
204
};
205
205
206
+ struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
207
+ using OpRewritePattern::OpRewritePattern;
208
+
209
+ LogicalResult matchAndRewrite (IsInfOp op,
210
+ PatternRewriter &rewriter) const override {
211
+ // We assume values to be finite and turn `IsInf` info `false`.
212
+ rewriter.replaceOpWithNewOp <spirv::ConstantOp>(
213
+ op, op.getType (), getScalarOrSplatAttr (op.getType (), 0 ));
214
+ return success ();
215
+ }
216
+ };
217
+
218
+ struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
219
+ using OpRewritePattern::OpRewritePattern;
220
+
221
+ LogicalResult matchAndRewrite (IsNanOp op,
222
+ PatternRewriter &rewriter) const override {
223
+ // We assume values to be finite and turn `IsNan` info `false`.
224
+ rewriter.replaceOpWithNewOp <spirv::ConstantOp>(
225
+ op, op.getType (), getScalarOrSplatAttr (op.getType (), 0 ));
226
+ return success ();
227
+ }
228
+ };
229
+
206
230
// ===----------------------------------------------------------------------===//
207
231
// Passes
208
232
// ===----------------------------------------------------------------------===//
209
- class WebGPUPreparePass
210
- : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
211
- public:
233
+ struct WebGPUPreparePass final
234
+ : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
212
235
void runOnOperation () override {
213
236
RewritePatternSet patterns (&getContext ());
214
237
populateSPIRVExpandExtendedMultiplicationPatterns (patterns);
238
+ populateSPIRVExpandNonFiniteArithmeticPatterns (patterns);
215
239
216
240
if (failed (
217
241
applyPatternsAndFoldGreedily (getOperation (), std::move (patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
227
251
RewritePatternSet &patterns) {
228
252
// WGSL currently does not support extended multiplication ops, see:
229
253
// https://github.com/gpuweb/gpuweb/issues/1565.
230
- patterns.add <
231
- // clang-format off
232
- ExpandSMulExtendedPattern,
233
- ExpandUMulExtendedPattern,
234
- ExpandAddCarryPattern
235
- >(patterns.getContext ());
254
+ patterns.add <ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
255
+ ExpandAddCarryPattern>(patterns.getContext ());
236
256
}
257
+
258
+ void populateSPIRVExpandNonFiniteArithmeticPatterns (
259
+ RewritePatternSet &patterns) {
260
+ // WGSL currently does not support `isInf` and `isNan`, see:
261
+ // https://github.com/gpuweb/gpuweb/pull/2311.
262
+ patterns.add <ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext ());
263
+ }
264
+
237
265
} // namespace spirv
238
266
} // namespace mlir
0 commit comments