@@ -2818,6 +2818,88 @@ class SPIRVExpectINTELInstBase : public SPIRVInstTemplateBase {
2818
2818
_SPIRV_OP_INTERNAL (ExpectINTEL, true , 5 )
2819
2819
#undef _SPIRV_OP_INTERNAL
2820
2820
2821
+ class SPIRVDotKHRBase : public SPIRVInstTemplateBase {
2822
+ protected:
2823
+ SPIRVCapVec getRequiredCapability () const override {
2824
+ // Both vector operands must have the same type, so analyzing the
2825
+ // first operand will suffice.
2826
+ SPIRVCapabilityKind ArgCap = getRequiredCapabilityForOperand (Ops[0 ]);
2827
+ return getVec (ArgCap, CapabilityDotProductKHR);
2828
+ }
2829
+
2830
+ llvm::Optional<ExtensionID> getRequiredExtension () const override {
2831
+ return ExtensionID::SPV_KHR_integer_dot_product;
2832
+ }
2833
+
2834
+ void validate () const override {
2835
+ SPIRVInstruction::validate ();
2836
+ SPIRVId Vec1 = Ops[0 ];
2837
+ SPIRVId Vec2 = Ops[1 ];
2838
+ (void )Vec1;
2839
+ (void )Vec2;
2840
+
2841
+ assert (getValueType (Vec1) == getValueType (Vec2) &&
2842
+ " Input vectors must have the same type" );
2843
+ assert (getType ()->isTypeInt () && " Result type must be an integer type" );
2844
+ assert (!getType ()->isTypeVector () && " Result type must be scalar" );
2845
+ }
2846
+
2847
+ private:
2848
+ bool isAccSat () const {
2849
+ return (OpCode == OpSDotAccSatKHR || OpCode == OpUDotAccSatKHR ||
2850
+ OpCode == OpSUDotAccSatKHR);
2851
+ }
2852
+
2853
+ Optional<PackedVectorFormat> getPackedVectorFormat () const {
2854
+ size_t PackFmtIdx = 2 ;
2855
+ if (isAccSat ()) {
2856
+ // AccSat instructions have an additional Accumulator operand.
2857
+ PackFmtIdx++;
2858
+ }
2859
+
2860
+ if (PackFmtIdx == Ops.size () - 1 )
2861
+ return static_cast <PackedVectorFormat>(Ops[PackFmtIdx]);
2862
+
2863
+ return None;
2864
+ }
2865
+
2866
+ SPIRVCapabilityKind getRequiredCapabilityForOperand (SPIRVId ArgId) const {
2867
+ const SPIRVType *T = getValueType (ArgId);
2868
+ if (auto PackFmt = getPackedVectorFormat ()) {
2869
+ switch (*PackFmt) {
2870
+ case PackedVectorFormatPackedVectorFormat4x8BitKHR:
2871
+ assert (!T->isTypeVector () && T->isTypeInt () && T->getBitWidth () == 32 &&
2872
+ " Type does not match pack format" );
2873
+ return CapabilityDotProductInput4x8BitPackedKHR;
2874
+ case PackedVectorFormatMax:
2875
+ break ;
2876
+ }
2877
+ llvm_unreachable (" Unknown Packed Vector Format" );
2878
+ }
2879
+
2880
+ if (T->isTypeVector ()) {
2881
+ const SPIRVType *EltT = T->getVectorComponentType ();
2882
+ if (T->getVectorComponentCount () == 4 && EltT->isTypeInt () &&
2883
+ EltT->getBitWidth () == 8 )
2884
+ return CapabilityDotProductInput4x8BitKHR;
2885
+ if (EltT->isTypeInt ())
2886
+ return CapabilityDotProductInputAllKHR;
2887
+ }
2888
+
2889
+ llvm_unreachable (" No mapping for argument type to capability." );
2890
+ }
2891
+ };
2892
+
2893
+ #define _SPIRV_OP (x, ...) \
2894
+ typedef SPIRVInstTemplate<SPIRVDotKHRBase, Op##x, __VA_ARGS__> SPIRV##x;
2895
+ _SPIRV_OP (SDotKHR, true , 5 , true , 2 )
2896
+ _SPIRV_OP(UDotKHR, true , 5 , true , 2 )
2897
+ _SPIRV_OP(SUDotKHR, true , 5 , true , 2 )
2898
+ _SPIRV_OP(SDotAccSatKHR, true , 6 , true , 3 )
2899
+ _SPIRV_OP(UDotAccSatKHR, true , 6 , true , 3 )
2900
+ _SPIRV_OP(SUDotAccSatKHR, true , 6 , true , 3 )
2901
+ #undef _SPIRV_OP
2902
+
2821
2903
class SPIRVSubgroupShuffleINTELInstBase : public SPIRVInstTemplateBase {
2822
2904
protected:
2823
2905
SPIRVCapVec getRequiredCapability () const override {
0 commit comments