diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.inc -gen-rewriters)
 mlir_tablegen(Ops.h.inc -gen-op-decls)
 mlir_tablegen(Ops.cpp.inc -gen-op-defs)
 mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -753,8 +753,20 @@
     static bool areCastCompatible(Type a, Type b);
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 0;
 }
+
+// Holds when the values bound to $0 and $1 have exactly the same type.
+def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
+
+// index_cast(index_cast(x)) -> x when the outer result has the same type as x.
+// NOTE(review): if a narrowing index_cast truncates, this fold also drops that
+// truncation (e.g. index -> i8 -> index); confirm that is intended.
+def CombineIndexCastPattern :
+    Pat<(IndexCastOp:$out (IndexCastOp $in)),
+        (replaceWithValue $in),
+        [(TypesAreIdentical $in, $out)]>;
 
 def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point to wider floating-point";
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -30,6 +30,11 @@
 // Pull in all enum type definitions and utility function declarations.
 #include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
 
+// Pull in the rewrite patterns generated from Ops.td (-gen-rewriters).
+namespace mlir {
+#include "mlir/Dialect/StandardOps/Ops.inc"
+} // end namespace mlir
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -1627,6 +1632,12 @@
          (a.isa<IntegerType>() && b.isIndex());
 }
 
+void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  // index_cast(index_cast(x)) -> x; see CombineIndexCastPattern in Ops.td.
+  results.insert<CombineIndexCastPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -878,3 +878,24 @@
   // CHECK: return %[[C7]], %[[C11]]
   return %7, %8 : index, index
 }
+
+// CHECK-LABEL: func @index_cast
+func @index_cast(%arg0: i16) -> (i16) {
+  // CHECK-NOT: index_cast
+  %11 = index_cast %arg0 : i16 to index
+  %12 = index_cast %11 : index to i16
+  // CHECK: return %arg0 : i16
+  return %12 : i16
+}
+
+// The outer result type differs from the original operand type, so the
+// index_cast pair must not be folded.
+// CHECK-LABEL: func @mismatched_index_casts
+func @mismatched_index_casts(%arg0: i16) -> (i32) {
+  // CHECK: %[[IDX:.*]] = index_cast %arg0 : i16 to index
+  // CHECK: %[[RES:.*]] = index_cast %[[IDX]] : index to i32
+  %0 = index_cast %arg0 : i16 to index
+  %1 = index_cast %0 : index to i32
+  // CHECK: return %[[RES]] : i32
+  return %1 : i32
+}