Skip to content

[DirectX][SPIRV] Fix the lowering of dot4add #140315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,24 +412,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
}
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Value *Acc = EmitScalarExpr(E->getArg(2));

Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
// Note that the argument order disagrees between the builtin and the
// intrinsic here.
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"hlsl.dot4add.i8packed");
/*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
nullptr, "hlsl.dot4add.i8packed");
}
case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Value *Acc = EmitScalarExpr(E->getArg(2));

Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
// Note that the argument order disagrees between the builtin and the
// intrinsic here.
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"hlsl.dot4add.u8packed");
/*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
nullptr, "hlsl.dot4add.u8packed");
}
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
Value *X = EmitScalarExpr(E->getArg(0));
Expand Down
29 changes: 18 additions & 11 deletions clang/test/CodeGenHLSL/builtins/dot4add_i8packed.hlsl
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=spv
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.4-compute %s -emit-llvm -o - | FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s -DTARGET=spv

// Test basic lowering to runtime function call.

// CHECK-LABEL: test
int test(uint a, uint b, int c) {
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.i8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return dot4add_i8packed(a, b, c);
int test(uint x, uint y, int acc) {
// CHECK: [[X_ADDR:%.*]] = alloca i32, align 4
// CHECK: [[Y_ADDR:%.*]] = alloca i32, align 4
// CHECK: [[ACC_ADDR:%.*]] = alloca i32, align 4
// CHECK: store i32 %x, ptr [[X_ADDR]], align 4
// CHECK: store i32 %y, ptr [[Y_ADDR]], align 4
// CHECK: store i32 %acc, ptr [[ACC_ADDR]], align 4
// CHECK: [[X0:%.*]] = load i32, ptr [[X_ADDR]], align 4
// CHECK: [[Y0:%.*]] = load i32, ptr [[Y_ADDR]], align 4
// CHECK: [[ACC0:%.*]] = load i32, ptr [[ACC_ADDR]], align 4
// CHECK: call i32 @llvm.[[TARGET]].dot4add.i8packed(i32 [[ACC0]], i32 [[X0]], i32 [[Y0]])
return dot4add_i8packed(x, y, acc);
}

// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.i8packed([[TY]], [[TY]], [[TY]])
[numthreads(1,1,1)]
void main() {
test(0, 0, 0);
}
32 changes: 19 additions & 13 deletions clang/test/CodeGenHLSL/builtins/dot4add_u8packed.hlsl
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@

// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=spv
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.4-compute %s -emit-llvm -o - | FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s -DTARGET=spv

// Test basic lowering to runtime function call.

// CHECK-LABEL: define {{.*}}test
uint test(uint a, uint b, uint c) {
// CHECK: %[[RET:.*]] = call [[TY:i32]] @llvm.[[TARGET]].dot4add.u8packed([[TY]] %[[#]], [[TY]] %[[#]], [[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return dot4add_u8packed(a, b, c);
// CHECK-LABEL: test
int test(uint x, uint y, int acc) {
// CHECK: [[X_ADDR:%.*]] = alloca i32, align 4
// CHECK: [[Y_ADDR:%.*]] = alloca i32, align 4
// CHECK: [[ACC_ADDR:%.*]] = alloca i32, align 4
// CHECK: store i32 %x, ptr [[X_ADDR]], align 4
// CHECK: store i32 %y, ptr [[Y_ADDR]], align 4
// CHECK: store i32 %acc, ptr [[ACC_ADDR]], align 4
// CHECK: [[X0:%.*]] = load i32, ptr [[X_ADDR]], align 4
// CHECK: [[Y0:%.*]] = load i32, ptr [[Y_ADDR]], align 4
// CHECK: [[ACC0:%.*]] = load i32, ptr [[ACC_ADDR]], align 4
// CHECK: call i32 @llvm.[[TARGET]].dot4add.u8packed(i32 [[ACC0]], i32 [[X0]], i32 [[Y0]])
return dot4add_u8packed(x, y, acc);
}

// CHECK: declare [[TY]] @llvm.[[TARGET]].dot4add.u8packed([[TY]], [[TY]], [[TY]])
[numthreads(1,1,1)]
void main() {
test(0, 0, 0);
}
14 changes: 8 additions & 6 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1119,19 +1119,21 @@ def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
"accumulate to i32";
let intrinsics = [IntrinSelect<int_dx_dot4add_i8packed>];
let arguments = [Int32Ty, Int32Ty, Int32Ty];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let result = OverloadTy;
let overloads = [Overloads<DXIL1_4, [Int32Ty]>];
let stages = [Stages<DXIL1_4, [all_stages]>];
let attributes = [Attributes<DXIL1_4, [ReadNone]>];
}

def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
let Doc = "unsigned dot product of 4 x i8 vectors packed into i32, with "
"accumulate to i32";
let intrinsics = [IntrinSelect<int_dx_dot4add_u8packed>];
let arguments = [Int32Ty, Int32Ty, Int32Ty];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
let result = OverloadTy;
let overloads = [Overloads<DXIL1_4, [Int32Ty]>];
let stages = [Stages<DXIL1_4, [all_stages]>];
let attributes = [Attributes<DXIL1_4, [ReadNone]>];
}

def AnnotateHandle : DXILOp<216, annotateHandle> {
Expand Down
20 changes: 13 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2021,20 +2021,24 @@ bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg,
assert(I.getOperand(4).isReg());
MachineBasicBlock &BB = *I.getParent();

Register Acc = I.getOperand(2).getReg();
Register X = I.getOperand(3).getReg();
Register Y = I.getOperand(4).getReg();

auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
Register Dot = MRI->createVirtualRegister(GR.getRegClass(ResType));
bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
.addDef(Dot)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.addUse(X)
.addUse(Y)
.constrainAllUses(TII, TRI, RBI);

return Result && BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Dot)
.addUse(I.getOperand(4).getReg())
.addUse(Acc)
.constrainAllUses(TII, TRI, RBI);
}

Expand All @@ -2052,8 +2056,10 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(

bool Result = true;

// Acc = C
Register Acc = I.getOperand(4).getReg();
Register Acc = I.getOperand(2).getReg();
Register X = I.getOperand(3).getReg();
Register Y = I.getOperand(4).getReg();

SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII);
auto ExtractOp =
Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract;
Expand All @@ -2067,7 +2073,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
.addDef(AElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(X)
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
Expand All @@ -2078,7 +2084,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
.addDef(BElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(3).getReg())
.addUse(Y)
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/DirectX/dot4add_i8packed.ll
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.4-compute %s | FileCheck %s

define void @main(i32 %a, i32 %b, i32 %c) {
define void @main(i32 %acc, i32 %x, i32 %y) {
entry:
; CHECK: call i32 @dx.op.dot4AddPacked(i32 163, i32 %a, i32 %b, i32 %c) #[[#ATTR:]]
%0 = call i32 @llvm.dx.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
; CHECK: call i32 @dx.op.dot4AddPacked.i32(i32 163, i32 %acc, i32 %x, i32 %y) #[[#ATTR:]]
%0 = call i32 @llvm.dx.dot4add.i8packed(i32 %acc, i32 %x, i32 %y)
ret void
}

Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/dot4add_i8packed_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s 2>&1 | FileCheck %s

; CHECK: in function f
; CHECK-SAME: Cannot create Dot4AddI8Packed operation: No valid overloads for DXIL version 1.3

define void @f(i32 %acc, i32 %x, i32 %y) {
entry:
%0 = call i32 @llvm.dx.dot4add.i8packed(i32 %acc, i32 %x, i32 %y)
ret void
}
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/DirectX/dot4add_u8packed.ll
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.4-compute %s | FileCheck %s

define void @main(i32 %a, i32 %b, i32 %c) {
define void @main(i32 %acc, i32 %x, i32 %y) {
entry:
; CHECK: call i32 @dx.op.dot4AddPacked(i32 164, i32 %a, i32 %b, i32 %c) #[[#ATTR:]]
%0 = call i32 @llvm.dx.dot4add.u8packed(i32 %a, i32 %b, i32 %c)
; CHECK: call i32 @dx.op.dot4AddPacked.i32(i32 164, i32 %acc, i32 %x, i32 %y) #[[#ATTR:]]
%0 = call i32 @llvm.dx.dot4add.u8packed(i32 %acc, i32 %x, i32 %y)
ret void
}

Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/dot4add_u8packed_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s 2>&1 | FileCheck %s

; CHECK: in function f
; CHECK-SAME: Cannot create Dot4AddU8Packed operation: No valid overloads for DXIL version 1.3

define void @f(i32 %acc, i32 %x, i32 %y) {
entry:
%0 = call i32 @llvm.dx.dot4add.u8packed(i32 %acc, i32 %x, i32 %y)
ret void
}
40 changes: 20 additions & 20 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_i8packed.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,49 @@
; CHECK-EXP-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24

; CHECK-LABEL: Begin function test_dot
define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
define noundef i32 @test_dot(i32 noundef %acc, i32 noundef %x, i32 noundef %y) {
entry:
; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#ACC:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#X:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#Y:]] = OpFunctionParameter %[[#int_32]]

; Test that we use the dot product op when capabilities allow

; CHECK-DOT: %[[#DOT:]] = OpSDot %[[#int_32]] %[[#A]] %[[#B]]
; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
; CHECK-DOT: %[[#DOT:]] = OpSDot %[[#int_32]] %[[#X]] %[[#Y]]
; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#ACC]]

; Test expansion is used when spirv dot product capabilities aren't available:

; First element of the packed vector
; CHECK-EXP: %[[#A0:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#B0:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
; CHECK-EXP: %[[#X0:]] = OpBitFieldSExtract %[[#int_32]] %[[#X]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#Y0:]] = OpBitFieldSExtract %[[#int_32]] %[[#Y]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#X0]] %[[#Y0]]
; CHECK-EXP: %[[#MASK0:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#ACC]] %[[#MASK0]]

; Second element of the packed vector
; CHECK-EXP: %[[#A1:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#B1:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
; CHECK-EXP: %[[#X1:]] = OpBitFieldSExtract %[[#int_32]] %[[#X]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#Y1:]] = OpBitFieldSExtract %[[#int_32]] %[[#Y]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#X1]] %[[#Y1]]
; CHECK-EXP: %[[#MASK1:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]

; Third element of the packed vector
; CHECK-EXP: %[[#A2:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#B2:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
; CHECK-EXP: %[[#X2:]] = OpBitFieldSExtract %[[#int_32]] %[[#X]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#Y2:]] = OpBitFieldSExtract %[[#int_32]] %[[#Y]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#X2]] %[[#Y2]]
; CHECK-EXP: %[[#MASK2:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]

; Fourth element of the packed vector
; CHECK-EXP: %[[#A3:]] = OpBitFieldSExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#B3:]] = OpBitFieldSExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
; CHECK-EXP: %[[#X3:]] = OpBitFieldSExtract %[[#int_32]] %[[#X]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#Y3:]] = OpBitFieldSExtract %[[#int_32]] %[[#Y]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#X3]] %[[#Y3]]
; CHECK-EXP: %[[#MASK3:]] = OpBitFieldSExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] %[[#eight]]

; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
; CHECK: OpReturnValue %[[#RES]]
%spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %a, i32 %b, i32 %c)
%spv.dot = call i32 @llvm.spv.dot4add.i8packed(i32 %acc, i32 %x, i32 %y)

ret i32 %spv.dot
}
40 changes: 20 additions & 20 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot4add_u8packed.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,49 @@
; CHECK-EXP-DAG: %[[#twentyfour:]] = OpConstant %[[#int_8]] 24

; CHECK-LABEL: Begin function test_dot
define noundef i32 @test_dot(i32 noundef %a, i32 noundef %b, i32 noundef %c) {
define noundef i32 @test_dot(i32 noundef %acc, i32 noundef %x, i32 noundef %y) {
entry:
; CHECK: %[[#A:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#B:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#C:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#ACC:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#X:]] = OpFunctionParameter %[[#int_32]]
; CHECK: %[[#Y:]] = OpFunctionParameter %[[#int_32]]

; Test that we use the dot product op when capabilities allow

; CHECK-DOT: %[[#DOT:]] = OpUDot %[[#int_32]] %[[#A]] %[[#B]]
; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#C]]
; CHECK-DOT: %[[#DOT:]] = OpUDot %[[#int_32]] %[[#X]] %[[#Y]]
; CHECK-DOT: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#DOT]] %[[#ACC]]

; Test expansion is used when spirv dot product capabilities aren't available:

; First element of the packed vector
; CHECK-EXP: %[[#A0:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#B0:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#A0]] %[[#B0]]
; CHECK-EXP: %[[#X0:]] = OpBitFieldUExtract %[[#int_32]] %[[#X]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#Y0:]] = OpBitFieldUExtract %[[#int_32]] %[[#Y]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#MUL0:]] = OpIMul %[[#int_32]] %[[#X0]] %[[#Y0]]
; CHECK-EXP: %[[#MASK0:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL0]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#C]] %[[#MASK0]]
; CHECK-EXP: %[[#ACC0:]] = OpIAdd %[[#int_32]] %[[#ACC]] %[[#MASK0]]

; Second element of the packed vector
; CHECK-EXP: %[[#A1:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#B1:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#A1]] %[[#B1]]
; CHECK-EXP: %[[#X1:]] = OpBitFieldUExtract %[[#int_32]] %[[#X]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#Y1:]] = OpBitFieldUExtract %[[#int_32]] %[[#Y]] %[[#eight]] %[[#eight]]
; CHECK-EXP: %[[#MUL1:]] = OpIMul %[[#int_32]] %[[#X1]] %[[#Y1]]
; CHECK-EXP: %[[#MASK1:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL1]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC1:]] = OpIAdd %[[#int_32]] %[[#ACC0]] %[[#MASK1]]

; Third element of the packed vector
; CHECK-EXP: %[[#A2:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#B2:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#A2]] %[[#B2]]
; CHECK-EXP: %[[#X2:]] = OpBitFieldUExtract %[[#int_32]] %[[#X]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#Y2:]] = OpBitFieldUExtract %[[#int_32]] %[[#Y]] %[[#sixteen]] %[[#eight]]
; CHECK-EXP: %[[#MUL2:]] = OpIMul %[[#int_32]] %[[#X2]] %[[#Y2]]
; CHECK-EXP: %[[#MASK2:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL2]] %[[#zero]] %[[#eight]]
; CHECK-EXP: %[[#ACC2:]] = OpIAdd %[[#int_32]] %[[#ACC1]] %[[#MASK2]]

; Fourth element of the packed vector
; CHECK-EXP: %[[#A3:]] = OpBitFieldUExtract %[[#int_32]] %[[#A]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#B3:]] = OpBitFieldUExtract %[[#int_32]] %[[#B]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#A3]] %[[#B3]]
; CHECK-EXP: %[[#X3:]] = OpBitFieldUExtract %[[#int_32]] %[[#X]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#Y3:]] = OpBitFieldUExtract %[[#int_32]] %[[#Y]] %[[#twentyfour]] %[[#eight]]
; CHECK-EXP: %[[#MUL3:]] = OpIMul %[[#int_32]] %[[#X3]] %[[#Y3]]
; CHECK-EXP: %[[#MASK3:]] = OpBitFieldUExtract %[[#int_32]] %[[#MUL3]] %[[#zero]] %[[#eight]]

; CHECK-EXP: %[[#RES:]] = OpIAdd %[[#int_32]] %[[#ACC2]] %[[#MASK3]]
; CHECK: OpReturnValue %[[#RES]]
%spv.dot = call i32 @llvm.spv.dot4add.u8packed(i32 %a, i32 %b, i32 %c)
%spv.dot = call i32 @llvm.spv.dot4add.u8packed(i32 %acc, i32 %x, i32 %y)

ret i32 %spv.dot
}
Loading