Skip to content

[LLVM][IR] Add textual shorthand for specifying constant vector splats. #74620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,11 @@ constants and smaller complex constants.
"``< i32 42, i32 11, i32 74, i32 100 >``". Vector constants
must have :ref:`vector type <t_vector>`, and the number and types of
elements must match those specified by the type.

When creating a vector whose elements have the same constant value, the
preferred syntax is ``splat (<Ty> Val)``. For example: "``splat (i32 11)``".
These vector constants must have ::ref:`vector type <t_vector>` with an
element type that matches the ``splat`` operand.
**Zero initialization**
The string '``zeroinitializer``' can be used to zero initialize a
value to zero of *any* type, including scalar and
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/AsmParser/LLParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace llvm {
t_Poison, // No value.
t_EmptyArray, // No value: []
t_Constant, // Value in ConstantVal.
t_ConstantSplat, // Value in ConstantVal.
t_InlineAsm, // Value in FTy/StrVal/StrVal2/UIntVal.
t_ConstantStruct, // Value in ConstantStructElts.
t_PackedConstantStruct // Value in ConstantStructElts.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/AsmParser/LLToken.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ enum Kind {
kw_extractelement,
kw_insertelement,
kw_shufflevector,
kw_splat,
kw_extractvalue,
kw_insertvalue,
kw_blockaddress,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/AsmParser/LLLexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(uinc_wrap);
KEYWORD(udec_wrap);

KEYWORD(splat);
KEYWORD(vscale);
KEYWORD(x);
KEYWORD(blockaddress);
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3981,6 +3981,21 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
return false;
}

case lltok::kw_splat: {
Lex.Lex();
if (parseToken(lltok::lparen, "expected '(' after vector splat"))
return true;
Constant *C;
if (parseGlobalTypeAndValue(C))
return true;
if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
return true;

ID.ConstantVal = C;
ID.Kind = ValID::t_ConstantSplat;
return false;
}

case lltok::kw_getelementptr:
case lltok::kw_shufflevector:
case lltok::kw_insertelement:
Expand Down Expand Up @@ -5824,6 +5839,17 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
"' but expected '" + getTypeString(Ty) + "'");
V = ID.ConstantVal;
return false;
case ValID::t_ConstantSplat:
if (!Ty->isVectorTy())
return error(ID.Loc, "vector constant must have vector type");
if (ID.ConstantVal->getType() != Ty->getScalarType())
return error(ID.Loc, "constant expression type mismatch: got type '" +
getTypeString(ID.ConstantVal->getType()) +
"' but expected '" +
getTypeString(Ty->getScalarType()) + "'");
V = ConstantVector::getSplat(cast<VectorType>(Ty)->getElementCount(),
ID.ConstantVal);
return false;
case ValID::t_ConstantStruct:
case ValID::t_PackedConstantStruct:
if (StructType *ST = dyn_cast<StructType>(Ty)) {
Expand Down Expand Up @@ -5861,6 +5887,7 @@ bool LLParser::parseConstantValue(Type *Ty, Constant *&C) {
case ValID::t_APFloat:
case ValID::t_Undef:
case ValID::t_Constant:
case ValID::t_ConstantSplat:
case ValID::t_ConstantStruct:
case ValID::t_PackedConstantStruct: {
Value *V;
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/Assembler/constant-splat-diagnostics.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: rm -rf %t && split-file %s %t

; RUN: not llvm-as < %t/not_a_constant.ll -o /dev/null 2>&1 | FileCheck -check-prefix=NOT_A_CONSTANT %s
; RUN: not llvm-as < %t/not_a_sclar.ll -o /dev/null 2>&1 | FileCheck -check-prefix=NOT_A_SCALAR %s
; RUN: not llvm-as < %t/not_a_vector.ll -o /dev/null 2>&1 | FileCheck -check-prefix=NOT_A_VECTOR %s
; RUN: not llvm-as < %t/wrong_explicit_type.ll -o /dev/null 2>&1 | FileCheck -check-prefix=WRONG_EXPLICIT_TYPE %s
; RUN: not llvm-as < %t/wrong_implicit_type.ll -o /dev/null 2>&1 | FileCheck -check-prefix=WRONG_IMPLICIT_TYPE %s

;--- not_a_constant.ll
; NOT_A_CONSTANT: error: expected instruction opcode
define <4 x i32> @not_a_constant(i32 %a) {
%splat = splat (i32 %a)
ret <vscale x 4 x i32> %splat
}

;--- not_a_sclar.ll
; NOT_A_SCALAR: error: constant expression type mismatch: got type '<1 x i32>' but expected 'i32'
define <4 x i32> @not_a_scalar() {
ret <4 x i32> splat (<1 x i32> <i32 7>)
}

;--- not_a_vector.ll
; NOT_A_VECTOR: error: vector constant must have vector type
define <4 x i32> @not_a_vector() {
ret i32 splat (i32 7)
}

;--- wrong_explicit_type.ll
; WRONG_EXPLICIT_TYPE: error: constant expression type mismatch: got type 'i8' but expected 'i32'
define <4 x i32> @wrong_explicit_type() {
ret <4 x i32> splat (i8 7)
}

;--- wrong_implicit_type.ll
; WRONG_IMPLICIT_TYPE: error: constant expression type mismatch: got type 'i8' but expected 'i32'
define void @wrong_implicit_type(<4 x i32> %a) {
%add = add <4 x i32> %a, splat (i8 7)
ret void
}

67 changes: 67 additions & 0 deletions llvm/test/Assembler/constant-splat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
; RUN: llvm-as < %s | llvm-dis | llvm-as | llvm-dis | FileCheck %s

; NOTE: Tests the expansion of the "splat" shorthand method to create vector
; constants. Future work will change how "splat" is expanded, ultimately
; leading to a point where "splat" is emitted as the disassembly.

@my_global = external global i32

; CHECK: @constant.splat.i1 = constant <1 x i1> <i1 true>
@constant.splat.i1 = constant <1 x i1> splat (i1 true)

; CHECK: @constant.splat.i32 = constant <5 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7>
@constant.splat.i32 = constant <5 x i32> splat (i32 7)

; CHECK: @constant.splat.i128 = constant <2 x i128> <i128 85070591730234615870450834276742070272, i128 85070591730234615870450834276742070272>
@constant.splat.i128 = constant <2 x i128> splat (i128 85070591730234615870450834276742070272)

; CHECK: @constant.splat.f16 = constant <4 x half> <half 0xHBC00, half 0xHBC00, half 0xHBC00, half 0xHBC00>
@constant.splat.f16 = constant <4 x half> splat (half 0xHBC00)

; CHECK: @constant.splat.f32 = constant <5 x float> <float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00>
@constant.splat.f32 = constant <5 x float> splat (float -2.000000e+00)

; CHECK: @constant.splat.f64 = constant <3 x double> <double -3.000000e+00, double -3.000000e+00, double -3.000000e+00>
@constant.splat.f64 = constant <3 x double> splat (double -3.000000e+00)

; CHECK: @constant.splat.128 = constant <2 x fp128> <fp128 0xL00000000000000018000000000000000, fp128 0xL00000000000000018000000000000000>
@constant.splat.128 = constant <2 x fp128> splat (fp128 0xL00000000000000018000000000000000)

; CHECK: @constant.splat.bf16 = constant <4 x bfloat> <bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0>
@constant.splat.bf16 = constant <4 x bfloat> splat (bfloat 0xRC0A0)

; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> <x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800>
@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)

; CHECK: @constant.splat.ppc_fp128 = constant <1 x ppc_fp128> <ppc_fp128 0xM80000000000000000000000000000000>
@constant.splat.ppc_fp128 = constant <1 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)

; CHECK: @constant.splat.global.ptr = constant <4 x ptr> <ptr @my_global, ptr @my_global, ptr @my_global, ptr @my_global>
@constant.splat.global.ptr = constant <4 x ptr> splat (ptr @my_global)

define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) {
; CHECK: %add = add <4 x i32> %a, <i32 137, i32 137, i32 137, i32 137>
%add = add <4 x i32> %a, splat (i32 137)
ret void
}

define <4 x i32> @ret_fixed_lenth_vector_splat_i32() {
; CHECK: ret <4 x i32> <i32 56, i32 56, i32 56, i32 56>
ret <4 x i32> splat (i32 56)
}

define void @add_fixed_lenth_vector_splat_double(<vscale x 2 x double> %a) {
; CHECK: %add = fadd <vscale x 2 x double> %a, shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> poison, double 5.700000e+00, i64 0), <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer)
%add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
ret void
}

define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
; CHECK: ret <vscale x 4 x i32> shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 78, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
ret <vscale x 4 x i32> splat (i32 78)
}

define <vscale x 4 x ptr> @ret_scalable_vector_ptr() {
; CHECK: ret <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @my_global, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer)
ret <vscale x 4 x ptr> splat (ptr @my_global)
}