diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -88,7 +88,7 @@
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,9 +116,33 @@
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
 
-void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                   MLIRContext *context) {
-  results.add<ConvertChainedBitcast>(context);
+/// Folds a bitcast whose operand already has the result type, and a round
+/// trip through another bitcast, to the original value. A longer chain of
+/// bitcasts is shortened in place to cast directly from the inner source.
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value arg = getOperand();
+  if (getType() == arg.getType())
+    return arg;
+
+  // Look through nested bitcasts.
+  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
+    Value nestedArg = bitcast.getOperand();
+    // T -> U -> T round trip: forward the original value.
+    if (nestedArg.getType() == getType())
+      return nestedArg;
+
+    // bitcast(bitcast(x)) -> bitcast(x). Returning this op's own result
+    // tells the folding driver that the op was updated in place.
+    getOperandMutable().assign(nestedArg);
+    return getResult();
+  }
+
+  // TODO(kuhar): Consider constant-folding the operand attribute.
+  // Nothing was folded: return a null result. Returning getResult() here
+  // would report a successful in-place fold on every call even though the op
+  // is unchanged, so fixed-point drivers like the greedy canonicalizer would
+  // keep revisiting it and never converge.
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -13,13 +13,6 @@
 include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"
 
-//===----------------------------------------------------------------------===//
-// spirv.Bitcast
-//===----------------------------------------------------------------------===//
-
-def ConvertChainedBitcast : Pat<(SPIRV_BitcastOp (SPIRV_BitcastOp $operand)),
-                                (SPIRV_BitcastOp $operand)>;
-
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -86,6 +86,30 @@
 
 // -----
 
+// CHECK-LABEL: @convert_bitcast_roundtrip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to i64
+  spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bitcast_chained_roundtrip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+  %2 = spirv.Bitcast %1 : vector<2xi32> to vector<2xf32>
+  %3 = spirv.Bitcast %2 : vector<2xf32> to i64
+  spirv.ReturnValue %3 : i64
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtract
 //===----------------------------------------------------------------------===//