diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -456,6 +456,31 @@
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : ArithmeticCastOp<"bitcast"> {
+  let summary = "bitcast between values of equal bit width";
+  let description = [{
+    Bitcast an integer or floating point value to an integer or floating point
+    value of equal bit width. When operating on vectors, casts elementwise.
+
+    Note that this implements a logical bitcast independent of target
+    endianness. This allows constant folding without target information and is
+    consistent with the bitcast constant folders in LLVM (see
+    https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168).
+    For targets where the source and target type have the same endianness (which
+    is the standard), this cast will also change no bits at runtime, but it may
+    still require an operation, for example if the machine has different
+    floating point and integer register files. For targets that have a different
+    endianness for the source and target types (e.g. float is big-endian and
+    integer is little-endian) a proper lowering would add operations to swap the
+    order of words in addition to the bitcast.
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -482,6 +483,66 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  assert(inputs.size() == 1 && outputs.size() == 1 &&
+         "bitcast op expects one operand and result");
+  Type a = inputs.front(), b = outputs.front();
+  // Signless int/float scalars are compatible iff their bit widths match;
+  // anything else is delegated to the vector helper with this predicate.
+  if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat())
+    return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
+}
+
+OpFoldResult BitcastOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "bitcastop expects 1 operand");
+
+  // bitcast(bitcast(x)) -> bitcast(x): skip the intermediate cast in place;
+  // returning this op's own result reports an in-place fold.
+  auto *sourceOp = getOperand().getDefiningOp();
+  if (auto sourceBitcast = dyn_cast_or_null<BitcastOp>(sourceOp)) {
+    setOperand(sourceBitcast.getOperand());
+    return getResult();
+  }
+
+  auto operand = operands[0];
+  if (!operand)
+    return {};
+
+  Type resType = getResult().getType();
+
+  if (auto denseAttr = operand.dyn_cast<DenseFPElementsAttr>()) {
+    Type elType = getElementTypeOrSelf(resType);
+    return denseAttr.mapValues(
+        elType, [](const APFloat &f) { return f.bitcastToAPInt(); });
+  }
+  if (auto denseAttr = operand.dyn_cast<DenseIntElementsAttr>()) {
+    Type elType = getElementTypeOrSelf(resType);
+    // mapValues does its own bitcast to the target type.
+    return denseAttr.mapValues(elType, [](const APInt &i) { return i; });
+  }
+
+  // Scalar constant: reinterpret its raw bits as the result type.
+  APInt bits;
+  if (auto floatAttr = operand.dyn_cast<FloatAttr>())
+    bits = floatAttr.getValue().bitcastToAPInt();
+  else if (auto intAttr = operand.dyn_cast<IntegerAttr>())
+    bits = intAttr.getValue();
+  else
+    return {};
+
+  if (resType.isa<IntegerType>())
+    return IntegerAttr::get(resType, bits);
+  if (auto resFloatType = resType.dyn_cast<FloatType>())
+    return FloatAttr::get(resType,
+                          APFloat(resFloatType.getFloatSemantics(), bits));
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -331,3 +331,102 @@
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// -----
+
+// CHECK-LABEL: @bitcastSameType(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastSameType(%arg : f32) -> f32 {
+  // CHECK: return %[[ARG]]
+  %res = bitcast %arg : f32 to f32
+  return %res : f32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantFPtoI(
+func @bitcastConstantFPtoI() -> i32 {
+  // CHECK: %[[CST:.+]] = constant 1065353216 : i32
+  // CHECK: return %[[CST]]
+  %c0 = constant 1.0 : f32
+  %res = bitcast %c0 : f32 to i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantItoFP(
+func @bitcastConstantItoFP() -> f32 {
+  // CHECK: %[[CST:.+]] = constant 1.0{{.*}} : f32
+  // CHECK: return %[[CST]]
+  %c0 = constant 1065353216 : i32
+  %res = bitcast %c0 : i32 to f32
+  return %res : f32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantFPtoFP(
+func @bitcastConstantFPtoFP() -> f16 {
+  // CHECK: %[[CST:.+]] = constant 1.875{{.*}} : f16
+  // CHECK: return %[[CST]]
+  %c0 = constant 1.0 : bf16
+  %res = bitcast %c0 : bf16 to f16
+  return %res : f16
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecFPtoI(
+func @bitcastConstantVecFPtoI() -> vector<3xi32> {
+  // CHECK: %[[CST:.+]] = constant dense<1065353216> : vector<3xi32>
+  // CHECK: return %[[CST]]
+  %c0 = constant dense<1.0> : vector<3xf32>
+  %res = bitcast %c0 : vector<3xf32> to vector<3xi32>
+  return %res : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecItoFP(
+func @bitcastConstantVecItoFP() -> vector<3xf32> {
+  // CHECK: %[[CST:.+]] = constant dense<1.0{{.*}}> : vector<3xf32>
+  // CHECK: return %[[CST]]
+  %c0 = constant dense<1065353216> : vector<3xi32>
+  %res = bitcast %c0 : vector<3xi32> to vector<3xf32>
+  return %res : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastConstantVecFPtoFP(
+func @bitcastConstantVecFPtoFP() -> vector<3xbf16> {
+  // CHECK: %[[CST:.+]] = constant dense<7.8125{{.*}}> : vector<3xbf16>
+  // CHECK: return %[[CST]]
+  %c0 = constant dense<1.0> : vector<3xf16>
+  %res = bitcast %c0 : vector<3xf16> to vector<3xbf16>
+  return %res : vector<3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastBackAndForth(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastBackAndForth(%arg : i32) -> i32 {
+  // CHECK: return %[[ARG]]
+  %f = bitcast %arg : i32 to f32
+  %res = bitcast %f : f32 to i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: @bitcastOfBitcast(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @bitcastOfBitcast(%arg : i16) -> i16 {
+  // CHECK: return %[[ARG]]
+  %f = bitcast %arg : i16 to f16
+  %bf = bitcast %f : f16 to bf16
+  %res = bitcast %bf : bf16 to i16
+  return %res : i16
+}
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -85,3 +85,11 @@
   %0:2 = call @return_i32_f32() : () -> (f32, i32)
   return
 }
+
+// -----
+
+func @bitcast_different_bit_widths(%arg : f16) -> f32 {
+  // expected-error@+1 {{are cast incompatible}}
+  %res = bitcast %arg : f16 to f32
+  return %res : f32
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -80,3 +80,10 @@
   %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
   return %result : complex<f64>
 }
+
+// CHECK-LABEL: func @bitcast(
+func @bitcast(%arg : f32) -> i32 {
+  // CHECK: bitcast %{{.*}} : f32 to i32
+  %res = bitcast %arg : f32 to i32
+  return %res : i32
+}