Skip to content

Commit 09d2ef8

Browse files
committed
Refine description and set 'valid' flag according to the resulting landID
1 parent f85e4a2 commit 09d2ef8

File tree

3 files changed

+106
-9
lines changed

3 files changed

+106
-9
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,8 +1332,9 @@ def GPU_ShuffleOp : GPU_Op<
13321332
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
13331333
```
13341334

1335-
For lane `k`, returns the value from lane `(k + cst1)`. The resulting value
1336-
is undefined if the lane is out of bounds in the subgroup.
1335+
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
1336+
bigger than or equal to `width`, the value is unspecified and `valid` is
1337+
`false`.
13371338

13381339
`up` example:
13391340

@@ -1342,8 +1343,8 @@ def GPU_ShuffleOp : GPU_Op<
13421343
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
13431344
```
13441345

1345-
For lane `k`, returns the value from lane `(k - cst1)`. The resulting value
1346-
is undefined if the lane is out of bounds in the subgroup.
1346+
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
1347+
smaller than `0`, the value is unspecified and `valid` is `false`.
13471348

13481349
`idx` example:
13491350

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,15 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
416416
return success();
417417
}
418418

419+
template <typename T>
420+
Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc,
421+
gpu::Dimension dimension) {
422+
Type indexType = IndexType::get(ctx);
423+
IntegerType i32Type = IntegerType::get(ctx, 32);
424+
Value dim = builder.create<T>(loc, indexType, dimension);
425+
return builder.create<arith::IndexCastOp>(loc, i32Type, dim);
426+
}
427+
419428
//===----------------------------------------------------------------------===//
420429
// Shuffle
421430
//===----------------------------------------------------------------------===//
@@ -436,8 +445,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
436445
shuffleOp, "shuffle width and target subgroup size mismatch");
437446

438447
Location loc = shuffleOp.getLoc();
439-
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
440-
shuffleOp.getLoc(), rewriter);
448+
Value validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
449+
shuffleOp.getLoc(), rewriter);
441450
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
442451
Value result;
443452

@@ -450,17 +459,65 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
450459
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
451460
loc, scope, adaptor.getValue(), adaptor.getOffset());
452461
break;
453-
case gpu::ShuffleMode::DOWN:
462+
case gpu::ShuffleMode::DOWN: {
454463
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
455464
loc, scope, adaptor.getValue(), adaptor.getOffset());
465+
466+
MLIRContext *ctx = shuffleOp.getContext();
467+
Value dimX =
468+
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469+
Value dimY =
470+
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471+
Value tidX =
472+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473+
Value tidY =
474+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475+
Value tidZ =
476+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477+
auto i32Type = rewriter.getIntegerType(32);
478+
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
479+
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
480+
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
481+
Value landId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
482+
483+
Value resultLandId =
484+
rewriter.create<arith::AddIOp>(loc, landId, adaptor.getOffset());
485+
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486+
resultLandId, adaptor.getWidth());
456487
break;
457-
case gpu::ShuffleMode::UP:
488+
}
489+
case gpu::ShuffleMode::UP: {
458490
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
459491
loc, scope, adaptor.getValue(), adaptor.getOffset());
492+
493+
MLIRContext *ctx = shuffleOp.getContext();
494+
Value dimX =
495+
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496+
Value dimY =
497+
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498+
Value tidX =
499+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500+
Value tidY =
501+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502+
Value tidZ =
503+
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504+
auto i32Type = rewriter.getIntegerType(32);
505+
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
506+
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
507+
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
508+
Value landId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
509+
510+
Value resultLandId =
511+
rewriter.create<arith::SubIOp>(loc, landId, adaptor.getOffset());
512+
validVal = rewriter.create<arith::CmpIOp>(
513+
loc, arith::CmpIPredicate::sge, resultLandId,
514+
rewriter.create<arith::ConstantOp>(
515+
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
460516
break;
461517
}
518+
}
462519

463-
rewriter.replaceOp(shuffleOp, {result, trueVal});
520+
rewriter.replaceOp(shuffleOp, {result, validVal});
464521
return success();
465522
}
466523

mlir/test/Conversion/GPUToSPIRV/shuffle.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,25 @@ gpu.module @kernels {
9494
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
9595
// CHECK: %{{.+}} = spirv.Constant true
9696
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
97+
98+
// CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32
99+
// CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32
100+
// CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
101+
// CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32>
102+
// CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32>
103+
// CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
104+
// CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32>
105+
// CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32>
106+
// CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
107+
// CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32>
108+
// CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32>
109+
// CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32
110+
// CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32
111+
// CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32
112+
// CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32
113+
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
114+
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
115+
97116
%result, %valid = gpu.shuffle down %val, %offset, %width : f32
98117
gpu.return
99118
}
@@ -122,6 +141,26 @@ gpu.module @kernels {
122141
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
123142
// CHECK: %{{.+}} = spirv.Constant true
124143
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
144+
145+
// CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32
146+
// CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32
147+
// CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
148+
// CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32>
149+
// CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32>
150+
// CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
151+
// CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32>
152+
// CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32>
153+
// CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
154+
// CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32>
155+
// CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32>
156+
// CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32
157+
// CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32
158+
// CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32
159+
// CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32
160+
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
161+
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
162+
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32
163+
125164
%result, %valid = gpu.shuffle up %val, %offset, %width : f32
126165
gpu.return
127166
}

0 commit comments

Comments
 (0)