diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1592,8 +1592,13 @@
   if (cast && cast.getOperand().getType() == getType()) {
     return cast.getOperand();
   }
+  // Fold IndexCast(constant) -> constant. The attribute is rebuilt from a
+  // plain integer with the result type, because the operand attribute carries
+  // the operand's type and the constant's bit width may need to change.
+  if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getInt());
+
   return {};
-
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -887,3 +887,16 @@
   // CHECK: return %[[ARG_0]] : i16
   return %12 : i16
 }
+
+// CHECK-LABEL: func @index_cast_fold
+func @index_cast_fold() -> (i16, index) {
+  %c4 = constant 4 : index
+  %1 = index_cast %c4 : index to i16
+  %c4_i16 = constant 4 : i16
+  %2 = index_cast %c4_i16 : i16 to index
+  // CHECK: %[[C4_I16:.*]] = constant 4 : i16
+  // CHECK: %[[C4:.*]] = constant 4 : index
+  // CHECK-NOT: index_cast
+  // CHECK: return %[[C4_I16]], %[[C4]] : i16, index
+  return %1, %2 : i16, index
+}