diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -89,6 +89,7 @@
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,6 +116,17 @@
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
 
+/// Folds a bitcast whose result type is identical to its operand type by
+/// forwarding the operand. This folder does not look through chained
+/// bitcasts: a round trip such as i64 -> f64 -> i64 relies on the
+/// ConvertChainedBitcast canonicalization pattern to merge it first.
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  if (getType() == getOperand().getType())
+    return getOperand();
+
+  return {};
+}
+
 void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
   results.add<ConvertChainedBitcast>(context);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -86,6 +86,29 @@
 
 // -----
 
+// CHECK-LABEL: @convert_bitcast_roundtrip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to i64
+  spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bitcast_chained_roundtrip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+  %2 = spirv.Bitcast %1 : vector<2xi32> to i64
+  spirv.ReturnValue %2 : i64
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtract
 //===----------------------------------------------------------------------===//