Skip to content

Commit dd5165a

Browse files
committed
[mlir] replace LLVM dialect float types with built-ins
Continue the convergence between LLVM dialect and built-in types by replacing the bfloat, half, float and double LLVM dialect types with their built-in counterparts. At the API level, this is a direct replacement. At the syntax level, we change the keywords to `bf16`, `f16`, `f32` and `f64`, respectively, to be compatible with the built-in type syntax. The old keywords can still be parsed but produce a deprecation warning and will be eventually removed. Depends On D94178 Reviewed By: mehdi_amini, silvas, antiagainst Differential Revision: https://reviews.llvm.org/D94179
1 parent 2e1bb79 commit dd5165a

File tree

73 files changed

+2493
-2558
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+2493
-2558
lines changed

mlir/docs/ConversionToLLVMDialect.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ Scalar types are converted to their LLVM counterparts if they exist. The
2525
following conversions are currently implemented:
2626

2727
- `i*` converts to `!llvm.i*`
28-
- `bf16` converts to `!llvm.bfloat`
29-
- `f16` converts to `!llvm.half`
30-
- `f32` converts to `!llvm.float`
31-
- `f64` converts to `!llvm.double`
28+
- `bf16` converts to `bf16`
29+
- `f16` converts to `f16`
30+
- `f32` converts to `f32`
31+
- `f64` converts to `f64`
3232

3333
### Index Type
3434

@@ -48,8 +48,8 @@ size with element type converted using these conversion rules. In the
4848
n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types
4949
of one-dimensional vectors.
5050

51-
For example, `vector<4 x f32>` converts to `!llvm.vec<4 x float>` and `vector<4
52-
x 8 x 16 x f32>` converts to `!llvm.array<4 x array<8 x vec<16 x float>>>`.
51+
For example, `vector<4 x f32>` converts to `!llvm.vec<4 x f32>` and `vector<4 x
52+
8 x 16 x f32>` converts to `!llvm.array<4 x array<8 x vec<16 x f32>>>`.
5353

5454
### Ranked Memref Types
5555

@@ -106,18 +106,18 @@ resulting in a struct containing two pointers + offset.
106106
Examples:
107107

108108
```mlir
109-
memref<f32> -> !llvm.struct<(ptr<float> , ptr<float>, i64)>
110-
memref<1 x f32> -> !llvm.struct<(ptr<float>, ptr<float>, i64,
109+
memref<f32> -> !llvm.struct<(ptr<f32> , ptr<f32>, i64)>
110+
memref<1 x f32> -> !llvm.struct<(ptr<f32>, ptr<f32>, i64,
111111
array<1 x 64>, array<1 x i64>)>
112-
memref<? x f32> -> !llvm.struct<(ptr<float>, ptr<float>, i64
112+
memref<? x f32> -> !llvm.struct<(ptr<f32>, ptr<f32>, i64
113113
array<1 x 64>, array<1 x i64>)>
114-
memref<10x42x42x43x123 x f32> -> !llvm.struct<(ptr<float>, ptr<float>, i64
114+
memref<10x42x42x43x123 x f32> -> !llvm.struct<(ptr<f32>, ptr<f32>, i64
115115
array<5 x 64>, array<5 x i64>)>
116-
memref<10x?x42x?x123 x f32> -> !llvm.struct<(ptr<float>, ptr<float>, i64
116+
memref<10x?x42x?x123 x f32> -> !llvm.struct<(ptr<f32>, ptr<f32>, i64
117117
array<5 x 64>, array<5 x i64>)>
118118
119119
// Memref types can have vectors as element types
120-
memref<1x? x vector<4xf32>> -> !llvm.struct<(ptr<vec<4 x float>>,
120+
memref<1x? x vector<4xf32>> -> !llvm.struct<(ptr<vec<4 x f32>>,
121121
ptr<vec<4 x float>>, i64,
122122
array<1 x i64>, array<1 x i64>)>
123123
```
@@ -132,11 +132,11 @@ attribute.
132132
Examples:
133133

134134
```mlir
135-
memref<f32> -> !llvm.ptr<float>
136-
memref<10x42 x f32> -> !llvm.ptr<float>
135+
memref<f32> -> !llvm.ptr<f32>
136+
memref<10x42 x f32> -> !llvm.ptr<f32>
137137
138138
// Memrefs with vector types are also supported.
139-
memref<10x42 x vector<4xf32>> -> !llvm.ptr<vec<4 x float>>
139+
memref<10x42 x vector<4xf32>> -> !llvm.ptr<vec<4 x f32>>
140140
```
141141

142142
### Unranked Memref types
@@ -196,12 +196,12 @@ Examples:
196196
// Binary function with one result:
197197
(i32, f32) -> (i64)
198198
// has its arguments handled separately
199-
!llvm.func<i64 (i32, float)>
199+
!llvm.func<i64 (i32, f32)>
200200
201201
// Binary function with two results:
202202
(i32, f32) -> (i64, f64)
203203
// has its result aggregated into a structure type.
204-
!llvm.func<struct<(i64, double)> (i32, float)>
204+
!llvm.func<struct<(i64, f64)> (i32, f32)>
205205
```
206206

207207
#### Functions as Function Arguments or Results
@@ -249,19 +249,19 @@ Examples:
249249
// A memref descriptor appearing as function argument:
250250
(memref<f32>) -> ()
251251
// gets converted into a list of individual scalar components of a descriptor.
252-
!llvm.func<void (ptr<float>, ptr<float>, i64)>
252+
!llvm.func<void (ptr<f32>, ptr<f32>, i64)>
253253
254254
// The list of arguments is linearized and one can freely mix memref and other
255255
// types in this list:
256256
(memref<f32>, f32) -> ()
257257
// which gets converted into a flat list.
258-
!llvm.func<void (ptr<float>, ptr<float>, i64, float)>
258+
!llvm.func<void (ptr<f32>, ptr<f32>, i64, f32)>
259259
260260
// For nD ranked memref descriptors:
261261
(memref<?x?xf32>) -> ()
262262
// the converted signature will contain 2n+1 `index`-typed integer arguments,
263263
// offset, n sizes and n strides, per memref argument type.
264-
!llvm.func<void (ptr<float>, ptr<float>, i64, i64, i64, i64, i64)>
264+
!llvm.func<void (ptr<f32>, ptr<f32>, i64, i64, i64, i64, i64)>
265265
266266
// Same rules apply to unranked descriptors:
267267
(memref<*xf32>) -> ()
@@ -271,12 +271,12 @@ Examples:
271271
// However, returning a memref from a function is not affected:
272272
() -> (memref<?xf32>)
273273
// gets converted to a function returning a descriptor structure.
274-
!llvm.func<struct<(ptr<float>, ptr<float>, i64, array<1xi64>, array<1xi64>)> ()>
274+
!llvm.func<struct<(ptr<f32>, ptr<f32>, i64, array<1xi64>, array<1xi64>)> ()>
275275
276276
// If multiple memref-typed results are returned:
277277
() -> (memref<f32>, memref<f64>)
278278
// their descriptor structures are additionally packed into another structure,
279279
// potentially with other non-memref typed results.
280-
!llvm.func<struct<(struct<(ptr<float>, ptr<float>, i64)>,
280+
!llvm.func<struct<(struct<(ptr<f32>, ptr<f32>, i64)>,
281281
struct<(ptr<double>, ptr<double>, i64)>)> ()>
282282
```

mlir/docs/Dialects/LLVM.md

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Examples:
115115
```mlir
116116
// Create an undefined value of structure type with a 32-bit integer followed
117117
// by a float.
118-
%0 = llvm.mlir.undef : !llvm.struct<(i32, float)>
118+
%0 = llvm.mlir.undef : !llvm.struct<(i32, f32)>
119119
120120
// Null pointer to i8.
121121
%1 = llvm.mlir.null : !llvm.ptr<i8>
@@ -127,7 +127,7 @@ Examples:
127127
%3 = llvm.mlir.constant(42 : i32) : i32
128128
129129
// Splat dense vector constant.
130-
%3 = llvm.mlir.constant(dense<1.0> : vector<4xf32>) : !llvm.vec<4 x float>
130+
%3 = llvm.mlir.constant(dense<1.0> : vector<4xf32>) : !llvm.vec<4 x f32>
131131
```
132132

133133
Note that constants use built-in types within the initializer definition: MLIR
@@ -214,14 +214,6 @@ containing an 8-bit and a 32-bit integer.
214214

215215
The following non-parametric types are supported.
216216

217-
- `!llvm.bfloat` (`LLVMBFloatType`) - 16-bit “brain” floating-point value
218-
(7-bit significand).
219-
- `!llvm.half` (`LLVMHalfType`) - 16-bit floating-point value as per
220-
IEEE-754-2008.
221-
- `!llvm.float` (`LLVMFloatType`) - 32-bit floating-point value as per
222-
IEEE-754-2008.
223-
- `!llvm.double` (`LLVMDoubleType`) - 64-bit floating-point value as per
224-
IEEE-754-2008.
225217
- `!llvm.fp128` (`LLVMFP128Type`) - 128-bit floating-point value as per
226218
IEEE-754-2008.
227219
- `!llvm.x86_fp80` (`LLVMX86FP80Type`) - 80-bit floating-point value (x87).
@@ -322,7 +314,7 @@ For example,
322314

323315
```mlir
324316
!llvm.func<void ()> // a function with no arguments;
325-
!llvm.func<i32 (float, i32)> // a function with two arguments and a result;
317+
!llvm.func<i32 (f32, i32)> // a function with two arguments and a result;
326318
!llvm.func<void (i32, ...)> // a variadic function with at least one argument.
327319
```
328320

mlir/docs/Dialects/Linalg.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,11 @@ func @example(%arg0: !llvm<"float*">, ...) {
429429
430430
llvm.func @pointwise_add(%arg0: !llvm<"float*">, ...) attributes {llvm.emit_c_interface} {
431431
...
432-
llvm.call @_mlir_ciface_pointwise_add(%9, %19, %29) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }
432+
llvm.call @_mlir_ciface_pointwise_add(%9, %19, %29) : (!llvm."{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ f32*, f32*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }
433433
*">) -> ()
434434
llvm.return
435435
}
436-
llvm.func @_mlir_ciface_pointwise_add(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) attributes {llvm.emit_c_interface}
436+
llvm.func @_mlir_ciface_pointwise_add(!llvm."{ float*, float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ f32*, f32*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ f32*, f32*, i64, [2 x i64], [2 x i64] }*">) attributes {llvm.emit_c_interface}
437437
```
438438

439439
##### Convention For External Library Interoperability

mlir/docs/Dialects/Vector.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,11 @@ Consider a vector of rank n with static sizes `{s_0, ... s_{n-1}}` (i.e. an
264264
MLIR `vector<s_0x...s_{n-1}xf32>`). Lowering such an `n-D` MLIR vector type to
265265
an LLVM descriptor can be done by either:
266266

267-
1. Flattening to a `1-D` vector: `!llvm<"(s_0*...*s_{n-1})xfloat">` in the
268-
MLIR LLVM dialect.
269-
2. Nested aggregate type of `1-D` vector:
270-
`!llvm<"[s_0x[s_1x[...<s_{n-1}xfloat>]]]">` in the MLIR LLVM dialect.
271-
3. A mix of both.
267+
1. Flattening to a `1-D` vector: `!llvm<"(s_0*...*s_{n-1})xfloat">` in the MLIR
268+
LLVM dialect.
269+
2. Nested aggregate type of `1-D` vector:
270+
`!llvm."[s_0x[s_1x[...<s_{n-1}xf32>]]]">` in the MLIR LLVM dialect.
271+
3. A mix of both.
272272

273273
There are multiple tradeoffs involved in choosing one or the other that we
274274
discuss. It is important to note that “a mix of both” immediately reduces to

mlir/docs/LLVMDialectMemRefConvention.md

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ func @foo(%arg0: memref<?xf32>) -> () {
8282
8383
// Gets converted to the following
8484
// (using type alias for brevity):
85-
!llvm.memref_1d = type !llvm.struct<(ptr<float>, ptr<float>, i64,
85+
!llvm.memref_1d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
8686
array<1xi64>, array<1xi64>)>
8787
88-
llvm.func @foo(%arg0: !llvm.ptr<float>, // Allocated pointer.
89-
%arg1: !llvm.ptr<float>, // Aligned pointer.
88+
llvm.func @foo(%arg0: !llvm.ptr<f32>, // Allocated pointer.
89+
%arg1: !llvm.ptr<f32>, // Aligned pointer.
9090
%arg2: i64, // Offset.
9191
%arg3: i64, // Size in dim 0.
9292
%arg4: i64) { // Stride in dim 0.
@@ -113,7 +113,7 @@ func @bar() {
113113
114114
// Gets converted to the following
115115
// (using type alias for brevity):
116-
!llvm.memref_1d = type !llvm.struct<(ptr<float>, ptr<float>, i64,
116+
!llvm.memref_1d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
117117
array<1xi64>, array<1xi64>)>
118118
119119
llvm.func @bar() {
@@ -264,11 +264,11 @@ func @qux(%arg0: memref<?x?xf32>)
264264
265265
// Gets converted into the following
266266
// (using type alias for brevity):
267-
!llvm.memref_2d = type !llvm.struct<(ptr<float>, ptr<float>, i64,
267+
!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
268268
array<2xi64>, array<2xi64>)>
269269
270270
// Function with unpacked arguments.
271-
llvm.func @qux(%arg0: !llvm.ptr<float>, %arg1: !llvm.ptr<float>,
271+
llvm.func @qux(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
272272
%arg2: i64, %arg3: i64, %arg4: i64,
273273
%arg5: i64, %arg6: i64) {
274274
// Populate memref descriptor (as per calling convention).
@@ -284,22 +284,22 @@ llvm.func @qux(%arg0: !llvm.ptr<float>, %arg1: !llvm.ptr<float>,
284284
// Store the descriptor in a stack-allocated space.
285285
%8 = llvm.mlir.constant(1 : index) : i64
286286
%9 = llvm.alloca %8 x !llvm.memref_2d
287-
: (i64) -> !llvm.ptr<struct<(ptr<float>, ptr<float>, i64,
287+
: (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
288288
array<2xi64>, array<2xi64>)>>
289-
llvm.store %7, %9 : !llvm.ptr<struct<(ptr<float>, ptr<float>, i64,
289+
llvm.store %7, %9 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
290290
array<2xi64>, array<2xi64>)>>
291291
292292
// Call the interface function.
293293
llvm.call @_mlir_ciface_qux(%9)
294-
: (!llvm.ptr<struct<(ptr<float>, ptr<float>, i64,
294+
: (!llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
295295
array<2xi64>, array<2xi64>)>>) -> ()
296296
297297
// The stored descriptor will be freed on return.
298298
llvm.return
299299
}
300300
301301
// Interface function.
302-
llvm.func @_mlir_ciface_qux(!llvm.ptr<struct<(ptr<float>, ptr<float>, i64,
302+
llvm.func @_mlir_ciface_qux(!llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
303303
array<2xi64>, array<2xi64>)>>)
304304
```
305305

@@ -310,13 +310,13 @@ func @foo(%arg0: memref<?x?xf32>) {
310310
311311
// Gets converted into the following
312312
// (using type alias for brevity):
313-
!llvm.memref_2d = type !llvm.struct<(ptr<float>, ptr<float>, i64,
313+
!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
314314
array<2xi64>, array<2xi64>)>
315-
!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<float>, ptr<float>, i64,
315+
!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
316316
array<2xi64>, array<2xi64>)>>
317317
318318
// Function with unpacked arguments.
319-
llvm.func @foo(%arg0: !llvm.ptr<float>, %arg1: !llvm.ptr<float>,
319+
llvm.func @foo(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
320320
%arg2: i64, %arg3: i64, %arg4: i64,
321321
%arg5: i64, %arg6: i64) {
322322
llvm.return
@@ -336,7 +336,7 @@ llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr) {
336336
%6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d
337337
%7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d
338338
llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
339-
: (!llvm.ptr<float>, !llvm.ptr<float>, i64, i64, i64,
339+
: (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64,
340340
i64, i64) -> ()
341341
llvm.return
342342
}
@@ -395,7 +395,7 @@ is transformed into the equivalent of the following code:
395395
// Compute the linearized index from strides.
396396
// When strides or, in absence of explicit strides, the corresponding sizes are
397397
// dynamic, extract the stride value from the descriptor.
398-
%stride1 = llvm.extractvalue[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64,
398+
%stride1 = llvm.extractvalue[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
399399
array<4xi64>, array<4xi64>)>
400400
%addr1 = muli %stride1, %1 : i64
401401
@@ -415,21 +415,21 @@ is transformed into the equivalent of the following code:
415415
416416
// If the linear offset is known to be zero, it can also be omitted. If it is
417417
// dynamic, it is extracted from the descriptor.
418-
%offset = llvm.extractvalue[2] : !llvm.struct<(ptr<float>, ptr<float>, i64,
418+
%offset = llvm.extractvalue[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
419419
array<4xi64>, array<4xi64>)>
420420
%addr7 = addi %addr6, %offset : i64
421421
422422
// All accesses are based on the aligned pointer.
423-
%aligned = llvm.extractvalue[1] : !llvm.struct<(ptr<float>, ptr<float>, i64,
423+
%aligned = llvm.extractvalue[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
424424
array<4xi64>, array<4xi64>)>
425425
426426
// Get the address of the data pointer.
427427
%ptr = llvm.getelementptr %aligned[%addr8]
428-
: !llvm.struct<(ptr<float>, ptr<float>, i64, array<4xi64>, array<4xi64>)>
429-
-> !llvm.ptr<float>
428+
: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4xi64>, array<4xi64>)>
429+
-> !llvm.ptr<f32>
430430
431431
// Perform the actual load.
432-
%0 = llvm.load %ptr : !llvm.ptr<float>
432+
%0 = llvm.load %ptr : !llvm.ptr<f32>
433433
```
434434

435435
For stores, the address computation code is identical and only the actual store

0 commit comments

Comments
 (0)