diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -758,7 +758,7 @@
     static bool areCastCompatible(Type a, Type b);
   }];
 
-  let hasFolder = 0;
+  let hasFolder = 1;
 }
 
 def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1586,6 +1586,20 @@
          (a.isa<IntegerType>() && b.isIndex());
 }
 
+OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
+  // Fold IndexCast(IndexCast(x)) -> x only for the round trip iN -> index ->
+  // iN. The opposite round trip index -> iN -> index is not an identity: the
+  // inner cast truncates to N bits and the outer cast sign-extends back, so
+  // any index value that does not fit in iN would be changed by the fold.
+  // NOTE: the iN -> index -> iN fold assumes index is at least N bits wide
+  // (index width is platform-specific).
+  auto cast = dyn_cast_or_null<IndexCastOp>(getOperand().getDefiningOp());
+  if (cast && !getType().isIndex() &&
+      cast.getOperand().getType() == getType())
+    return cast.getOperand();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -878,3 +878,23 @@
   // CHECK: return %[[C7]], %[[C11]]
   return %7, %8 : index, index
 }
+
+// CHECK-LABEL: func @index_cast
+// CHECK-SAME: %[[ARG_0:arg[0-9]+]]: i16
+func @index_cast(%arg0: i16) -> (i16) {
+  %11 = index_cast %arg0 : i16 to index
+  %12 = index_cast %11 : index to i16
+  // CHECK: return %[[ARG_0]] : i16
+  return %12 : i16
+}
+
+// CHECK-LABEL: func @index_cast_narrowing_round_trip
+// CHECK-SAME: %[[ARG_0:arg[0-9]+]]: index
+func @index_cast_narrowing_round_trip(%arg0: index) -> (index) {
+  // CHECK: %[[TRUNC:.*]] = index_cast %[[ARG_0]] : index to i16
+  // CHECK: %[[EXT:.*]] = index_cast %[[TRUNC]] : i16 to index
+  %11 = index_cast %arg0 : index to i16
+  %12 = index_cast %11 : i16 to index
+  // CHECK: return %[[EXT]] : index
+  return %12 : index
+}