Skip to content

Commit ca6baf1

Browse files
committed
[MLIR][std] Introduce bitcast operation
This patch introduces a bitcast operation to the standard dialect. RFC: https://llvm.discourse.group/t/rfc-introduce-a-bitcast-op/3774 Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D105376
1 parent 276be84 commit ca6baf1

File tree

5 files changed

+196
-0
lines changed

5 files changed

+196
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,32 @@ def AtomicYieldOp : Std_Op<"atomic_yield", [
456456
let assemblyFormat = "$result attr-dict `:` type($result)";
457457
}
458458

459+
//===----------------------------------------------------------------------===//
460+
// BitcastOp
461+
//===----------------------------------------------------------------------===//
462+
463+
def BitcastOp : ArithmeticCastOp<"bitcast"> {
464+
let summary = "bitcast between values of equal bit width";
465+
let description = [{
466+
Bitcast an integer or floating point value to an integer or floating point
467+
value of equal bit width. When operating on vectors, casts elementwise.
468+
469+
Note that this implements a logical bitcast independent of target
470+
endianness. This allows constant folding without target information and is
471+
consitent with the bitcast constant folders in LLVM (see
472+
https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168)
473+
For targets where the source and target type have the same endianness (which
474+
is the standard), this cast will also change no bits at runtime, but it may
475+
still require an operation, for example if the machine has different
476+
floating point and integer register files. For targets that have a different
477+
endianness for the source and target types (e.g. float is big-endian and
478+
integer is little-endian) a proper lowering would add operations to swap the
479+
order of words in addition to the bitcast.
480+
}];
481+
let hasFolder = 1;
482+
}
483+
484+
459485
//===----------------------------------------------------------------------===//
460486
// BranchOp
461487
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Value.h"
2424
#include "mlir/Support/MathExtras.h"
2525
#include "mlir/Transforms/InliningUtils.h"
26+
#include "llvm/ADT/APFloat.h"
2627
#include "llvm/ADT/STLExtras.h"
2728
#include "llvm/ADT/StringSwitch.h"
2829
#include "llvm/Support/FormatVariadic.h"
@@ -482,6 +483,62 @@ static LogicalResult verify(AtomicYieldOp op) {
482483
return success();
483484
}
484485

486+
//===----------------------------------------------------------------------===//
487+
// BitcastOp
488+
//===----------------------------------------------------------------------===//
489+
490+
bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
491+
assert(inputs.size() == 1 && outputs.size() == 1 &&
492+
"bitcast op expects one operand and result");
493+
Type a = inputs.front(), b = outputs.front();
494+
if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat())
495+
return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
496+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
497+
}
498+
499+
OpFoldResult BitcastOp::fold(ArrayRef<Attribute> operands) {
500+
assert(operands.size() == 1 && "bitcastop expects 1 operand");
501+
502+
// Bitcast of bitcast
503+
auto *sourceOp = getOperand().getDefiningOp();
504+
if (auto sourceBitcast = dyn_cast_or_null<BitcastOp>(sourceOp)) {
505+
setOperand(sourceBitcast.getOperand());
506+
return getResult();
507+
}
508+
509+
auto operand = operands[0];
510+
if (!operand)
511+
return {};
512+
513+
Type resType = getResult().getType();
514+
515+
if (auto denseAttr = operand.dyn_cast<DenseFPElementsAttr>()) {
516+
Type elType = getElementTypeOrSelf(resType);
517+
return denseAttr.mapValues(
518+
elType, [](const APFloat &f) { return f.bitcastToAPInt(); });
519+
}
520+
if (auto denseAttr = operand.dyn_cast<DenseIntElementsAttr>()) {
521+
Type elType = getElementTypeOrSelf(resType);
522+
// mapValues does its own bitcast to the target type.
523+
return denseAttr.mapValues(elType, [](const APInt &i) { return i; });
524+
}
525+
526+
APInt bits;
527+
if (auto floatAttr = operand.dyn_cast<FloatAttr>())
528+
bits = floatAttr.getValue().bitcastToAPInt();
529+
else if (auto intAttr = operand.dyn_cast<IntegerAttr>())
530+
bits = intAttr.getValue();
531+
else
532+
return {};
533+
534+
if (resType.isa<IntegerType>())
535+
return IntegerAttr::get(resType, bits);
536+
if (auto resFloatType = resType.dyn_cast<FloatType>())
537+
return FloatAttr::get(resType,
538+
APFloat(resFloatType.getFloatSemantics(), bits));
539+
return {};
540+
}
541+
485542
//===----------------------------------------------------------------------===//
486543
// BranchOp
487544
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,102 @@ func @selToNot(%arg0: i1) -> i1 {
331331
%res = select %arg0, %false, %true : i1
332332
return %res : i1
333333
}
334+
335+
// -----
336+
337+
// CHECK-LABEL: @bitcastSameType(
338+
// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
339+
func @bitcastSameType(%arg : f32) -> f32 {
340+
// CHECK: return %[[ARG]]
341+
%res = bitcast %arg : f32 to f32
342+
return %res : f32
343+
}
344+
345+
// -----
346+
347+
// CHECK-LABEL: @bitcastConstantFPtoI(
348+
func @bitcastConstantFPtoI() -> i32 {
349+
// CHECK: %[[C0:.+]] = constant 0 : i32
350+
// CHECK: return %[[C0]]
351+
%c0 = constant 0.0 : f32
352+
%res = bitcast %c0 : f32 to i32
353+
return %res : i32
354+
}
355+
356+
// -----
357+
358+
// CHECK-LABEL: @bitcastConstantItoFP(
359+
func @bitcastConstantItoFP() -> f32 {
360+
// CHECK: %[[C0:.+]] = constant 0.0{{.*}} : f32
361+
// CHECK: return %[[C0]]
362+
%c0 = constant 0 : i32
363+
%res = bitcast %c0 : i32 to f32
364+
return %res : f32
365+
}
366+
367+
// -----
368+
369+
// CHECK-LABEL: @bitcastConstantFPtoFP(
370+
func @bitcastConstantFPtoFP() -> f16 {
371+
// CHECK: %[[C0:.+]] = constant 0.0{{.*}} : f16
372+
// CHECK: return %[[C0]]
373+
%c0 = constant 0.0 : bf16
374+
%res = bitcast %c0 : bf16 to f16
375+
return %res : f16
376+
}
377+
378+
// -----
379+
380+
// CHECK-LABEL: @bitcastConstantVecFPtoI(
381+
func @bitcastConstantVecFPtoI() -> vector<3xf32> {
382+
// CHECK: %[[C0:.+]] = constant dense<0.0{{.*}}> : vector<3xf32>
383+
// CHECK: return %[[C0]]
384+
%c0 = constant dense<0> : vector<3xi32>
385+
%res = bitcast %c0 : vector<3xi32> to vector<3xf32>
386+
return %res : vector<3xf32>
387+
}
388+
389+
// -----
390+
391+
// CHECK-LABEL: @bitcastConstantVecItoFP(
392+
func @bitcastConstantVecItoFP() -> vector<3xi32> {
393+
// CHECK: %[[C0:.+]] = constant dense<0> : vector<3xi32>
394+
// CHECK: return %[[C0]]
395+
%c0 = constant dense<0.0> : vector<3xf32>
396+
%res = bitcast %c0 : vector<3xf32> to vector<3xi32>
397+
return %res : vector<3xi32>
398+
}
399+
400+
// -----
401+
402+
// CHECK-LABEL: @bitcastConstantVecFPtoFP(
403+
func @bitcastConstantVecFPtoFP() -> vector<3xbf16> {
404+
// CHECK: %[[C0:.+]] = constant dense<0.0{{.*}}> : vector<3xbf16>
405+
// CHECK: return %[[C0]]
406+
%c0 = constant dense<0.0> : vector<3xf16>
407+
%res = bitcast %c0 : vector<3xf16> to vector<3xbf16>
408+
return %res : vector<3xbf16>
409+
}
410+
411+
// -----
412+
413+
// CHECK-LABEL: @bitcastBackAndForth(
414+
// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
415+
func @bitcastBackAndForth(%arg : i32) -> i32 {
416+
// CHECK: return %[[ARG]]
417+
%f = bitcast %arg : i32 to f32
418+
%res = bitcast %f : f32 to i32
419+
return %res : i32
420+
}
421+
422+
// -----
423+
424+
// CHECK-LABEL: @bitcastOfBitcast(
425+
// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
426+
func @bitcastOfBitcast(%arg : i16) -> i16 {
427+
// CHECK: return %[[ARG]]
428+
%f = bitcast %arg : i16 to f16
429+
%bf = bitcast %f : f16 to bf16
430+
%res = bitcast %bf : bf16 to i16
431+
return %res : i16
432+
}

mlir/test/Dialect/Standard/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,11 @@ func @call() {
8585
%0:2 = call @return_i32_f32() : () -> (f32, i32)
8686
return
8787
}
88+
89+
// -----
90+
91+
func @bitcast_different_bit_widths(%arg : f16) -> f32 {
92+
// expected-error@+1 {{are cast incompatible}}
93+
%res = bitcast %arg : f16 to f32
94+
return %res : f32
95+
}

mlir/test/Dialect/Standard/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,9 @@ func @constant_complex_f64() -> complex<f64> {
8080
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
8181
return %result : complex<f64>
8282
}
83+
84+
// CHECK-LABEL: func @bitcast(
85+
func @bitcast(%arg : f32) -> i32 {
86+
%res = bitcast %arg : f32 to i32
87+
return %res : i32
88+
}

0 commit comments

Comments
 (0)