Skip to content

Commit 2e7ed78

Browse files
authored
[mlir][spirv] Add instruction OpGroupNonUniformRotateKHR (#133428)
Add an instruction under the extension SPV_KHR_subgroup_rotate. The specification for the extension is here: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html
1 parent 662d385 commit 2e7ed78

File tree

4 files changed

+167
-1
lines changed

4 files changed

+167
-1
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
44894489
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
44904490
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
44914491
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
4492+
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
44924493
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
44934494
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
44944495
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
45984599
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
45994600
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
46004601
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
4601-
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
4602+
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
4603+
SPIRV_OC_OpSubgroupBallotKHR,
46024604
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
46034605
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
46044606
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,4 +1361,78 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
13611361

13621362
// -----
13631363

1364+
def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
1365+
Pure, AllTypesMatch<["value", "result"]>]> {
1366+
let summary = [{
1367+
Rotate values across invocations within a subgroup.
1368+
}];
1369+
1370+
let description = [{
1371+
Return the Value of the invocation whose id within the group is calculated
1372+
as follows:
1373+
1374+
LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
1375+
LocalInvocationId if Execution is Workgroup
1376+
RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
1377+
RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
1378+
and SubgroupSize if not.
1379+
Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
1380+
(LocalId & ~(RotationGroupSize - 1))
1381+
1382+
Result Type must be a scalar or vector of floating-point type, integer
1383+
type, or Boolean type.
1384+
1385+
Execution is a Scope. It must be either Workgroup or Subgroup.
1386+
1387+
The type of Value must be the same as Result Type.
1388+
1389+
Delta must be a scalar of integer type, whose Signedness operand is 0.
1390+
Delta must be dynamically uniform within Execution.
1391+
1392+
Delta is treated as unsigned and the resulting value is undefined if the
1393+
selected lane is inactive.
1394+
1395+
ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
1396+
integer type, whose Signedness operand is 0. ClusterSize must come from a
1397+
constant instruction. Behavior is undefined unless ClusterSize is at least
1398+
1 and a power of 2. If ClusterSize is greater than the declared
1399+
SubGroupSize, executing this instruction results in undefined behavior.
1400+
1401+
<!-- End of AutoGen section -->
1402+
1403+
#### Example:
1404+
1405+
```mlir
1406+
%four = spirv.Constant 4 : i32
1407+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
1408+
%1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
1409+
cluster_size(%four) : f32, i32, i32 -> f32
1410+
```
1411+
}];
1412+
1413+
let availability = [
1414+
MinVersion<SPIRV_V_1_3>,
1415+
MaxVersion<SPIRV_V_1_6>,
1416+
Extension<[]>,
1417+
Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
1418+
];
1419+
1420+
let arguments = (ins
1421+
SPIRV_ScopeAttr:$execution_scope,
1422+
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
1423+
SPIRV_SignlessOrUnsignedInt:$delta,
1424+
Optional<SPIRV_SignlessOrUnsignedInt>:$cluster_size
1425+
);
1426+
1427+
let results = (outs
1428+
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
1429+
);
1430+
1431+
let assemblyFormat = [{
1432+
$execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
1433+
}];
1434+
}
1435+
1436+
// -----
1437+
13641438
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS

mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,29 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
304304
return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
305305
}
306306

307+
//===----------------------------------------------------------------------===//
308+
// spirv.GroupNonUniformRotateKHR
309+
//===----------------------------------------------------------------------===//
310+
311+
LogicalResult GroupNonUniformRotateKHROp::verify() {
312+
spirv::Scope scope = getExecutionScope();
313+
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
314+
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
315+
316+
if (Value clusterSizeVal = getClusterSize()) {
317+
mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
318+
int32_t clusterSize = 0;
319+
320+
if (failed(extractValueFromConstOp(defOp, clusterSize)))
321+
return emitOpError("cluster size operand must come from a constant op");
322+
323+
if (!llvm::isPowerOf2_32(clusterSize))
324+
return emitOpError("cluster size operand must be a power of two");
325+
}
326+
327+
return success();
328+
}
329+
307330
//===----------------------------------------------------------------------===//
308331
// Group op verification
309332
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,70 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
604604
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
605605
return %0: i32
606606
}
607+
608+
// -----
609+
610+
//===----------------------------------------------------------------------===//
611+
// spirv.GroupNonUniformRotateKHR
612+
//===----------------------------------------------------------------------===//
613+
614+
// CHECK-LABEL: @group_non_uniform_rotate_khr
615+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
616+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
617+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
618+
return %0: f32
619+
}
620+
621+
// -----
622+
623+
// CHECK-LABEL: @group_non_uniform_rotate_khr
624+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
625+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
626+
%four = spirv.Constant 4 : i32
627+
%0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
628+
return %0: f32
629+
}
630+
631+
// -----
632+
633+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
634+
%four = spirv.Constant 4 : i32
635+
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
636+
%0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
637+
return %0: f32
638+
}
639+
640+
// -----
641+
642+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
643+
%four = spirv.Constant 4 : i32
644+
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
645+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
646+
return %0: f32
647+
}
648+
649+
// -----
650+
651+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
652+
%four = spirv.Constant 4 : si32
653+
// expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
654+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
655+
return %0: f32
656+
}
657+
658+
// -----
659+
660+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 {
661+
// expected-error @+1 {{cluster size operand must come from a constant op}}
662+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
663+
return %0: f32
664+
}
665+
666+
// -----
667+
668+
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
669+
%five = spirv.Constant 5 : i32
670+
// expected-error @+1 {{cluster size operand must be a power of two}}
671+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
672+
return %0: f32
673+
}

0 commit comments

Comments
 (0)