Fold an index cast of a constant into a constant, and test it.

The first hunk extends a fold routine: when the (constant) operand is an
integer attribute, the fold returns an IntegerAttr of the result type built
from value.getInt(). The second hunk adds the @index_cast_fold canonicalize
test, covering index -> i16 and i16 -> index.

NOTE(review): the template arguments <IndexCastOp> and <IntegerAttr> and the
indentation of context lines were reconstructed from surrounding usage after
the patch text was mangled -- confirm against the upstream change before
applying.

diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1591,6 +1591,13 @@
   auto cast = dyn_cast_or_null<IndexCastOp>(getOperand().getDefiningOp());
   if (cast && cast.getOperand().getType() == getType())
     return cast.getOperand();
+
+  // Fold IndexCast(constant) -> constant
+  // A little hack because we go through int. Otherwise, the size
+  // of the constant might need to change.
+  if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getInt());
+
   return {};
 }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -887,3 +887,15 @@
   // CHECK: return %[[ARG_0]] : i16
   return %12 : i16
 }
+
+// CHECK-LABEL: func @index_cast_fold
+func @index_cast_fold() -> (i16, index) {
+  %c4 = constant 4 : index
+  %1 = index_cast %c4 : index to i16
+  %c4_i16 = constant 4 : i16
+  %2 = index_cast %c4_i16 : i16 to index
+  // CHECK: %[[C4_I16:.*]] = constant 4 : i16
+  // CHECK: %[[C4:.*]] = constant 4 : index
+  // CHECK: return %[[C4_I16]], %[[C4]] : i16, index
+  return %1, %2 : i16, index
+}