Skip to content

Commit 2b3c505

Browse files
author
Sjoerd Meijer
committed
[Matrix] Intrinsic descriptions
This changes the matrix load/store intrinsic definitions to load/store from/to a pointer, and not from/to a pointer to a vector, as discussed in D83477. This also includes the recommit of "[Matrix] Tighten LangRef definitions and Verifier checks" which adds improved language reference descriptions of the matrix intrinsics and verifier checks. Differential Revision: https://reviews.llvm.org/D83785
1 parent be15284 commit 2b3c505

File tree

11 files changed

+324
-197
lines changed

11 files changed

+324
-197
lines changed

llvm/docs/LangRef.rst

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15525,6 +15525,7 @@ The argument to this intrinsic must be a vector of floating-point values.
1552515525

1552615526
Syntax:
1552715527
"""""""
15528+
This is an overloaded intrinsic.
1552815529

1552915530
::
1553015531

@@ -15549,17 +15550,20 @@ Matrix Intrinsics
1554915550
-----------------
1555015551

1555115552
Operations on matrixes requiring shape information (like number of rows/columns
15552-
or the memory layout) can be expressed using the matrix intrinsics. Matrixes are
15553-
embedded in a flat vector and the intrinsics take the dimensions as arguments.
15554-
Currently column-major layout is assumed. The intrinsics support both integer
15555-
and floating point matrixes.
15553+
or the memory layout) can be expressed using the matrix intrinsics. These
15554+
intrinsics require matrix dimensions to be passed as immediate arguments, and
15555+
matrixes are passed and returned as vectors. This means that for a ``R`` x
15556+
``C`` matrix, element ``i`` of column ``j`` is at index ``j * R + i`` in the
15557+
corresponding vector, with indices starting at 0. Currently column-major layout
15558+
is assumed. The intrinsics support both integer and floating point matrixes.
1555615559

1555715560

1555815561
'``llvm.matrix.transpose.*``' Intrinsic
15559-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15562+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1556015563

1556115564
Syntax:
1556215565
"""""""
15566+
This is an overloaded intrinsic.
1556315567

1556415568
::
1556515569

@@ -15568,21 +15572,24 @@ Syntax:
1556815572
Overview:
1556915573
"""""""""
1557015574

15571-
The '``llvm.matrix.transpose.*``' intrinsic treats %In as containing a matrix
15572-
with <Rows> rows and <Cols> columns and returns the transposed matrix embedded in
15573-
the result vector.
15575+
The '``llvm.matrix.transpose.*``' intrinsics treat %In as a <Rows> x <Cols> matrix
15576+
and return the transposed matrix in the result vector.
1557415577

1557515578
Arguments:
1557615579
""""""""""
1557715580

15578-
The <Rows> and <Cols> arguments must be constant integers. The vector argument
15579-
%In and the returned vector must have <Rows> * <Cols> elements.
15581+
First argument %In is vector that corresponds to a <Rows> x <Cols> matrix.
15582+
Thus, arguments <Rows> and <Cols> correspond to the number of rows and columns,
15583+
respectively, and must be positive, constant integers. The returned vector must
15584+
have <Rows> * <Cols> elements, and have the same float or integer element type
15585+
as %In.
1558015586

1558115587
'``llvm.matrix.multiply.*``' Intrinsic
15582-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15588+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1558315589

1558415590
Syntax:
1558515591
"""""""
15592+
This is an overloaded intrinsic.
1558615593

1558715594
::
1558815595

@@ -15591,25 +15598,27 @@ Syntax:
1559115598
Overview:
1559215599
"""""""""
1559315600

15594-
The '``llvm.matrix.multiply.*``' intrinsic treats %A as a matrix with <OuterRows>
15595-
rows and <Inner> columns, %B as a matrix with <Inner> rows and <OuterColumns>
15596-
columns and multiplies them. The result matrix is returned embedded in the
15597-
result vector.
15601+
The '``llvm.matrix.multiply.*``' intrinsics treat %A as a <OuterRows> x <Inner>
15602+
matrix, %B as a <Inner> x <OuterColumns> matrix, and multiplies them. The result
15603+
matrix is returned in the result vector.
1559815604

1559915605
Arguments:
1560015606
""""""""""
1560115607

15602-
The <OuterRows>, <Inner> and <OuterColumns> arguments must be constant
15603-
integers. The vector argument %A must have <OuterRows> * <Inner> elements, %B
15604-
must have <Inner> * <OuterColumns> elements and the returned vector must have
15605-
<OuterRows> * <OuterColumns> elements.
15608+
The first vector argument %A corresponds to a matrix with <OuterRows> * <Inner>
15609+
elements, and the second argument %B to a matrix with <Inner> * <OuterColumns>
15610+
elements. Arguments <OuterRows>, <Inner> and <OuterColumns> must be positive,
15611+
constant integers. The returned vector must have <OuterRows> * <OuterColumns>
15612+
elements. Vectors %A, %B, and the returned vector all have the same float or
15613+
integer element type.
1560615614

1560715615

1560815616
'``llvm.matrix.column.major.load.*``' Intrinsic
1560915617
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1561015618

1561115619
Syntax:
1561215620
"""""""
15621+
This is an overloaded intrinsic.
1561315622

1561415623
::
1561515624

@@ -15619,22 +15628,26 @@ Syntax:
1561915628
Overview:
1562015629
"""""""""
1562115630

15622-
The '``llvm.matrix.column.major.load.*``' intrinsic loads a matrix with <Rows>
15623-
rows and <Cols> columns, using a stride of %Stride between columns. For two
15624-
consecutive columns A and B, %Stride refers to the distance (the number of
15625-
elements) between the start of column A and the start of column B. The result
15626-
matrix is returned embedded in the result vector. This allows for convenient
15627-
loading of sub matrixes. If <IsVolatile> is true, the intrinsic is considered
15628-
a :ref:`volatile memory access <volatile>`.
15629-
15630-
If the %Ptr argument is known to be aligned to some boundary, this can be
15631-
specified as an attribute on the argument.
15631+
The '``llvm.matrix.column.major.load.*``' intrinsics load a <Rows> x <Cols>
15632+
matrix using a stride of %Stride to compute the start address of the different
15633+
columns. This allows for convenient loading of sub matrixes. If <IsVolatile>
15634+
is true, the intrinsic is considered a :ref:`volatile memory access
15635+
<volatile>`. The result matrix is returned in the result vector. If the %Ptr
15636+
argument is known to be aligned to some boundary, this can be specified as an
15637+
attribute on the argument.
1563215638

1563315639
Arguments:
1563415640
""""""""""
1563515641

15636-
The <IsVolatile>, <Rows> and <Cols> arguments must be constant integers. The
15637-
returned vector must have <Rows> * <Cols> elements. %Stride must be >= <Rows>.
15642+
The first argument %Ptr is a pointer type to the returned vector type, and
15643+
correponds to the start address to load from. The second argument %Stride is a
15644+
postive, constant integer with %Stride ``>=`` <Rows>. %Stride is used to compute
15645+
the column memory addresses. I.e., for a column ``C``, its start memory
15646+
addresses is calculated with %Ptr + ``C`` * %Stride. The third Argument
15647+
<IsVolatile> is a boolean value. The fourth and fifth arguments, <Rows> and
15648+
<Cols>, correspond to the number of rows and columns, respectively, and must be
15649+
positive, constant integers. The returned vector must have <Rows> * <Cols>
15650+
elements.
1563815651

1563915652
The :ref:`align <attr_align>` parameter attribute can be provided
1564015653
for the %Ptr arguments.
@@ -15654,21 +15667,26 @@ Syntax:
1565415667
Overview:
1565515668
"""""""""
1565615669

15657-
The '``llvm.matrix.column.major.store.*``' intrinsic stores the matrix with
15658-
<Rows> rows and <Cols> columns embedded in %In, using a stride of %Stride
15659-
between columns. For two consecutive columns A and B, %Stride refers to the
15660-
distance (the number of elements) between the start of column A and the start
15661-
of column B. If <IsVolatile> is true, the intrinsic is considered a
15662-
:ref:`volatile memory access <volatile>`.
15670+
The '``llvm.matrix.column.major.store.*``' intrinsics store the <Rows> x <Cols>
15671+
matrix in %In to memory using a stride of %Stride between columns. If
15672+
<IsVolatile> is true, the intrinsic is considered a :ref:`volatile memory
15673+
access <volatile>`.
1566315674

1566415675
If the %Ptr argument is known to be aligned to some boundary, this can be
1566515676
specified as an attribute on the argument.
1566615677

1566715678
Arguments:
1566815679
""""""""""
1566915680

15670-
The <IsVolatile>, <Rows>, <Cols> arguments must be constant integers. The
15671-
vector argument %In must have <Rows> * <Cols> elements. %Stride must be >= <Rows>.
15681+
The first argument %In is a vector that corresponds to a <Rows> x <Cols> matrix
15682+
to be stored to memory. The second argument %Ptr is a pointer to the vector
15683+
type of %In, and is the start address of the matrix in memory. The third
15684+
argument %Stride is a positive, constant integer with %Stride ``>=`` <Rows>.
15685+
%Stride is used to compute the column memory addresses. I.e., for a column
15686+
``C``, its start memory addresses is calculated with %Ptr + ``C`` * %Stride.
15687+
The fourth argument <IsVolatile> is a boolean value. The arguments <Rows> and
15688+
<Cols> correspond to the number of rows and columns, respectively, and must be
15689+
positive, constant integers.
1567215690

1567315691
The :ref:`align <attr_align>` parameter attribute can be provided
1567415692
for the %Ptr arguments.

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,15 +1458,15 @@ def int_matrix_multiply
14581458

14591459
def int_matrix_column_major_load
14601460
: Intrinsic<[llvm_anyvector_ty],
1461-
[LLVMAnyPointerType<LLVMMatchType<0>>, llvm_i64_ty, llvm_i1_ty,
1461+
[LLVMPointerToElt<0>, llvm_i64_ty, llvm_i1_ty,
14621462
llvm_i32_ty, llvm_i32_ty],
14631463
[IntrNoSync, IntrWillReturn, IntrArgMemOnly, IntrReadMem,
14641464
NoCapture<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>,
14651465
ImmArg<ArgIndex<4>>]>;
14661466

14671467
def int_matrix_column_major_store
14681468
: Intrinsic<[],
1469-
[llvm_anyvector_ty, LLVMAnyPointerType<LLVMMatchType<0>>,
1469+
[llvm_anyvector_ty, LLVMPointerToElt<0>,
14701470
llvm_i64_ty, llvm_i1_ty, llvm_i32_ty, llvm_i32_ty],
14711471
[IntrNoSync, IntrWillReturn, IntrArgMemOnly, IntrWriteMem,
14721472
WriteOnly<ArgIndex<1>>, NoCapture<ArgIndex<1>>,

llvm/lib/IR/Verifier.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5017,36 +5017,73 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
50175017
case Intrinsic::matrix_transpose:
50185018
case Intrinsic::matrix_column_major_load:
50195019
case Intrinsic::matrix_column_major_store: {
5020+
Function *IF = Call.getCalledFunction();
5021+
ConstantInt *Stride = nullptr;
50205022
ConstantInt *NumRows;
50215023
ConstantInt *NumColumns;
5022-
VectorType *TypeToCheck;
5024+
VectorType *ResultTy;
5025+
Type *Op0ElemTy = nullptr;
5026+
Type *Op1ElemTy = nullptr;
50235027
switch (ID) {
50245028
case Intrinsic::matrix_multiply:
50255029
NumRows = cast<ConstantInt>(Call.getArgOperand(2));
50265030
NumColumns = cast<ConstantInt>(Call.getArgOperand(4));
5027-
TypeToCheck = cast<VectorType>(Call.getType());
5031+
ResultTy = cast<VectorType>(Call.getType());
5032+
Op0ElemTy =
5033+
cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
5034+
Op1ElemTy =
5035+
cast<VectorType>(Call.getArgOperand(1)->getType())->getElementType();
50285036
break;
50295037
case Intrinsic::matrix_transpose:
50305038
NumRows = cast<ConstantInt>(Call.getArgOperand(1));
50315039
NumColumns = cast<ConstantInt>(Call.getArgOperand(2));
5032-
TypeToCheck = cast<VectorType>(Call.getType());
5040+
ResultTy = cast<VectorType>(Call.getType());
5041+
Op0ElemTy =
5042+
cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
50335043
break;
50345044
case Intrinsic::matrix_column_major_load:
5045+
Stride = dyn_cast<ConstantInt>(Call.getArgOperand(1));
50355046
NumRows = cast<ConstantInt>(Call.getArgOperand(3));
50365047
NumColumns = cast<ConstantInt>(Call.getArgOperand(4));
5037-
TypeToCheck = cast<VectorType>(Call.getType());
5048+
ResultTy = cast<VectorType>(Call.getType());
5049+
Op0ElemTy =
5050+
cast<PointerType>(Call.getArgOperand(0)->getType())->getElementType();
50385051
break;
50395052
case Intrinsic::matrix_column_major_store:
5053+
Stride = dyn_cast<ConstantInt>(Call.getArgOperand(2));
50405054
NumRows = cast<ConstantInt>(Call.getArgOperand(4));
50415055
NumColumns = cast<ConstantInt>(Call.getArgOperand(5));
5042-
TypeToCheck = cast<VectorType>(Call.getArgOperand(0)->getType());
5056+
ResultTy = cast<VectorType>(Call.getArgOperand(0)->getType());
5057+
Op0ElemTy =
5058+
cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
5059+
Op1ElemTy =
5060+
cast<PointerType>(Call.getArgOperand(1)->getType())->getElementType();
50435061
break;
50445062
default:
50455063
llvm_unreachable("unexpected intrinsic");
50465064
}
5047-
Assert(TypeToCheck->getNumElements() ==
5065+
5066+
Assert(ResultTy->getElementType()->isIntegerTy() ||
5067+
ResultTy->getElementType()->isFloatingPointTy(),
5068+
"Result type must be an integer or floating-point type!", IF);
5069+
5070+
Assert(ResultTy->getElementType() == Op0ElemTy,
5071+
"Vector element type mismatch of the result and first operand "
5072+
"vector!", IF);
5073+
5074+
if (Op1ElemTy)
5075+
Assert(ResultTy->getElementType() == Op1ElemTy,
5076+
"Vector element type mismatch of the result and second operand "
5077+
"vector!", IF);
5078+
5079+
Assert(ResultTy->getNumElements() ==
50485080
NumRows->getZExtValue() * NumColumns->getZExtValue(),
5049-
"result of a matrix operation does not fit in the returned vector");
5081+
"Result of a matrix operation does not fit in the returned vector!");
5082+
5083+
if (Stride)
5084+
Assert(Stride->getZExtValue() >= NumRows->getZExtValue(),
5085+
"Stride must be greater or equal than the number of rows!", IF);
5086+
50505087
break;
50515088
}
50525089
};

0 commit comments

Comments
 (0)