diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -455,6 +455,30 @@
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : ArithmeticCastOp<"bitcast">, Arguments<(ins AnyType:$in)> {
+  let summary = "bitcast between values of equal bit width";
+  let description = [{
+    Bitcast an integer or floating point value to an integer or floating point
+    value of equal bit width. When operating on vectors, casts elementwise.
+
+    The bits of the operand are reinterpreted as a value of the result type.
+    This is not guaranteed to be a no-op at runtime: on targets that keep
+    integer and floating point values in separate register files it may
+    require a move between them.
+
+    Example:
+
+    ```mlir
+    %1 = bitcast %0 : f32 to i32
+    %3 = bitcast %2 : vector<4xf16> to vector<4xi16>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -482,6 +482,26 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the single input type can be bitcast to the single output
+/// type: both must be signless integer or float types of the same bit width.
+/// All other type pairs are handed to areVectorCastSimpleCompatible, which
+/// re-applies this predicate to the element types of vectors.
+bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  // Bitcast is strictly one-to-one; reject anything else instead of asserting.
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
+  // `index` is not signless-int-or-float (it has no fixed bit width), so it
+  // never reaches getIntOrFloatBitWidth() here.
+  if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat())
+    return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -331,3 +331,13 @@
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// -----
+
+// CHECK-LABEL: @sameTypeBitcast(
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
+func @sameTypeBitcast(%arg : f32) -> f32 {
+  // CHECK: return %[[ARG]]
+  %res = bitcast %arg : f32 to f32
+  return %res : f32
+}
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -69,3 +69,27 @@
   %0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
   return
 }
+
+// -----
+
+func @bitcast_different_bit_widths(%arg : f16) -> f32 {
+  // expected-error@+1 {{are cast incompatible}}
+  %res = bitcast %arg : f16 to f32
+  return %res : f32
+}
+
+// -----
+
+func @bitcast_index(%arg : index) -> i64 {
+  // expected-error@+1 {{are cast incompatible}}
+  %res = bitcast %arg : index to i64
+  return %res : i64
+}
+
+// -----
+
+func @bitcast_vector_bit_width_mismatch(%arg : vector<2xf16>) -> vector<2xf32> {
+  // expected-error@+1 {{are cast incompatible}}
+  %res = bitcast %arg : vector<2xf16> to vector<2xf32>
+  return %res : vector<2xf32>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -80,3 +80,17 @@
   %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
   return %result : complex<f64>
 }
+
+// CHECK-LABEL: func @bitcast(
+func @bitcast(%arg : f32) -> i32 {
+  // CHECK: bitcast %{{.*}} : f32 to i32
+  %res = bitcast %arg : f32 to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: func @bitcast_vector(
+func @bitcast_vector(%arg : vector<2xf32>) -> vector<2xi32> {
+  // CHECK: bitcast %{{.*}} : vector<2xf32> to vector<2xi32>
+  %res = bitcast %arg : vector<2xf32> to vector<2xi32>
+  return %res : vector<2xi32>
+}