Skip to content

Commit 673604a

Browse files
authored
[mlir][vector] Update docs for scalable vectors (#101842)
Adds a few notes on scalable vectors in the docs for the Vector dialect. This is mostly "repeating" things from LLVM's LangRef. Additionally: * Adds a few basic tests with scalable vectors (those should've been added long time ago), * Updates a comment in "TypeConverter.cpp" (the current comment is out-of-date), * Includes small formatting edits in Vector.md. **NOTE** Depends on #101813 - only review the top commit
1 parent a98a0dc commit 673604a

File tree

4 files changed

+54
-25
lines changed

4 files changed

+54
-25
lines changed

mlir/docs/Dialects/Vector.md

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,30 @@ following top-down rewrites and conversions:
7474
### LLVM level
7575

7676
On CPU, the `n-D` `vector` type currently lowers to `!llvm<array<vector>>`.
77-
More concretely, `vector<4x8x128xf32>` lowers to `!llvm<[4 x [ 8 x [ 128 x
78-
float ]]]>`. There are tradeoffs involved related to how one can access
79-
subvectors and how one uses `llvm.extractelement`, `llvm.insertelement` and
80-
`llvm.shufflevector`. The section on [LLVM Lowering
81-
Tradeoffs](#llvm-lowering-tradeoffs) offers a deeper dive into the current
82-
design choices and tradeoffs.
77+
More concretely,
78+
* `vector<4x8x128xf32>` lowers to `!llvm<[4 x [ 8 x < 128
79+
x float >]]>` (fixed-width vector), and
80+
* `vector<4x8x[128]xf32>` lowers to `!llvm<[4 x [ 8 x < vscale x 128
81+
x float >]]>` (scalable vector).
82+
83+
There are tradeoffs involved related to how one can access subvectors and how
84+
one uses `llvm.extractelement`, `llvm.insertelement` and `llvm.shufflevector`.
85+
The section on [LLVM Lowering Tradeoffs](#llvm-lowering-tradeoffs) offers a
86+
deeper dive into the current design choices and tradeoffs.
87+
88+
Note, while LLVM supports arrarys of scalable vectors, these are required to be
89+
fixed-width arrays of 1-D scalable vectors. This means scalable vectors with a
90+
non-trailing scalable dimension (e.g. `vector<4x[8]x128xf32`) are not
91+
convertible to LLVM.
92+
93+
Finally, MLIR takes the same view on scalable Vectors as LLVM (c.f. (Vector
94+
Type)[https://llvm.org/docs/LangRef.html#vector-type]):
95+
> For scalable vectors, the total number of elements is a constant multiple
96+
> (called vscale) of the specified number of elements; vscale is a positive
97+
> integer that is unknown at compile time and the same hardware-dependent
98+
> constant for all scalable vectors at run time. The size of a specific
99+
> scalable vector type is thus constant within IR, even if the exact size in
100+
> bytes cannot be determined until run time.
83101
84102
### Hardware Vector Ops
85103

@@ -269,11 +287,6 @@ proposal for now, this assumes LLVM only has built-in support for 1-D vector.
269287
The relationship with the LLVM Matrix proposal is discussed at the end of this
270288
document.
271289

272-
MLIR does not currently support dynamic vector sizes (i.e. SVE style) so the
273-
discussion is limited to static rank and static vector sizes (e.g.
274-
`vector<4x8x16x32xf32>`). This section discusses operations on vectors in LLVM
275-
and MLIR.
276-
277290
LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g.
278291
`llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates
279292
following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR
@@ -287,10 +300,11 @@ Consider a vector of rank n with static sizes `{s_0, ... s_{n-1}}` (i.e. an MLIR
287300
`vector<s_0x...s_{n-1}xf32>`). Lowering such an `n-D` MLIR vector type to an
288301
LLVM descriptor can be done by either:
289302

290-
1. Flattening to a `1-D` vector: `!llvm<"(s_0*...*s_{n-1})xfloat">` in the MLIR
303+
1. Nested aggregate type of `1-D` vector:
304+
`!llvm."[s_0x[s_1x[...<s_{n-1}xf32>]]]">` in the MLIR LLVM dialect (current
305+
lowering in MLIR).
306+
2. Flattening to a `1-D` vector: `!llvm<"(s_0*...*s_{n-1})xfloat">` in the MLIR
291307
LLVM dialect.
292-
2. Nested aggregate type of `1-D` vector:
293-
`!llvm."[s_0x[s_1x[...<s_{n-1}xf32>]]]">` in the MLIR LLVM dialect.
294308
3. A mix of both.
295309

296310
There are multiple tradeoffs involved in choosing one or the other that we
@@ -303,9 +317,11 @@ vector<4x8x16x32xf32> to vector<4x4096xf32>` operation, that flattens the most
303317

304318
The first constraint was already mentioned: LLVM only supports `1-D` `vector`
305319
types natively. Additional constraints are related to the difference in LLVM
306-
between vector and aggregate types: `“Aggregate Types are a subset of derived
307-
types that can contain multiple member types. Arrays and structs are aggregate
308-
types. Vectors are not considered to be aggregate types.”.`
320+
between vector and
321+
[aggregate types](https://llvm.org/docs/LangRef.html#aggregate-types):
322+
> Aggregate Types are a subset of derived types that can contain multiple
323+
> member types. Arrays and structs are aggregate types. Vectors are not
324+
> considered to be aggregate types.
309325
310326
This distinction is also reflected in some of the operations. For `1-D` vectors,
311327
the operations `llvm.extractelement`, `llvm.insertelement`, and
@@ -314,12 +330,15 @@ vectors with `n>1`, and thus aggregate types at LLVM level, the more restrictive
314330
operations `llvm.extractvalue` and `llvm.insertvalue` apply, which only accept
315331
static indices. There is no direct shuffling support for aggregate types.
316332

317-
The next sentence illustrates a recurrent tradeoff, also found in MLIR, between
333+
The next sentence (cf. LangRef [structure
334+
type](https://llvm.org/docs/LangRef.html#structure-type)) illustrates a
335+
recurrent tradeoff, also found in MLIR, between
318336
“value types” (subject to SSA use-def chains) and “memory types” (subject to
319-
aliasing and side-effects): `“Structures in memory are accessed using ‘load’ and
320-
‘store’ by getting a pointer to a field with the llvm.getelementptr instruction.
321-
Structures in registers are accessed using the llvm.extractvalue and
322-
llvm.insertvalue instructions.”`
337+
aliasing and side-effects):
338+
> Structures in memory are accessed using ‘load’ and ‘store’ by getting a
339+
> pointer to a field with the llvm.getelementptr instruction. Structures in
340+
> registers are accessed using the llvm.extractvalue and llvm.insertvalue
341+
> instructions.
323342
324343
When transposing this to MLIR, `llvm.getelementptr` works on pointers to `n-D`
325344
vectors in memory. For `n-D`, vectors values that live in registers we can use

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
509509
/// * 1-D `vector<axT>` remains as is while,
510510
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
511511
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
512-
/// Returns failure for n-D scalable vector types as LLVM does not support
513-
/// arrays of scalable vectors.
512+
/// As LLVM supports arrays of scalable vectors, this method will also convert
513+
/// n-D scalable vectors provided that only the trailing dim is scalable.
514514
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
515515
auto elementType = convertType(type.getElementType());
516516
if (!elementType)
@@ -521,7 +521,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
521521
type.getScalableDims().back());
522522
assert(LLVM::isCompatibleVectorType(vectorType) &&
523523
"expected vector type compatible with the LLVM dialect");
524-
// Only the trailing dimension can be scalable.
524+
// For n-D vector types for which a _non-trailing_ dim is scalable,
525+
// return a failure. Supporting such cases would require LLVM
526+
// to support something akin "scalable arrays" of vectors.
525527
if (llvm::is_contained(type.getScalableDims().drop_back(), true))
526528
return failure();
527529
auto shape = type.getShape();

mlir/test/Dialect/LLVMIR/types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ func.func @array() {
9191
"some.op"() : () -> !llvm.array<10 x ptr<4>>
9292
// CHECK: !llvm.array<10 x array<4 x f32>>
9393
"some.op"() : () -> !llvm.array<10 x array<4 x f32>>
94+
// CHECK: !llvm.array<10 x array<4 x vector<8xf32>>>
95+
"some.op"() : () -> !llvm.array<10 x array<4 x vector<8xf32>>>
96+
// CHECK: !llvm.array<10 x array<4 x vector<[8]xf32>>>
97+
"some.op"() : () -> !llvm.array<10 x array<4 x vector<[8]xf32>>>
9498
return
9599
}
96100

mlir/test/Target/LLVMIR/llvmir-types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ llvm.func @return_a8_float() -> !llvm.array<8 x f32>
9999
llvm.func @return_a10_p_4() -> !llvm.array<10 x ptr<4>>
100100
// CHECK: declare [10 x [4 x float]] @return_a10_a4_float()
101101
llvm.func @return_a10_a4_float() -> !llvm.array<10 x array<4 x f32>>
102+
// CHECK: declare [10 x [4 x <4 x float>]] @return_a10_a4_v4_float()
103+
llvm.func @return_a10_a4_v4_float() -> !llvm.array<10 x array<4 x vector<4xf32>>>
104+
// CHECK: declare [10 x [4 x <vscale x 4 x float>]] @return_a10_a4_sv4_float()
105+
llvm.func @return_a10_a4_sv4_float() -> !llvm.array<10 x array<4 x vector<[4]xf32>>>
102106

103107
//
104108
// Literal structures.

0 commit comments

Comments
 (0)