diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -754,7 +754,7 @@
   }];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 0;
+  let hasFolder = 1;
 }
 def TypesAreIdentical : Constraint>;
 def CombineIndexCastPattern :
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1637,6 +1637,18 @@
   results.insert(context);
 }
 
+/// Folds an index_cast whose operand is a constant integer attribute into an
+/// integer attribute of the result type carrying the same value.
+OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> operands) {
+  // Rebuild the constant from its int64_t value instead of reusing the
+  // attribute: the result type may have a different size than the operand
+  // type, so the constant has to be recreated with the result type.
+  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getInt());
+  // The operand is not a constant integer; return a null result (no fold).
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -886,3 +886,13 @@
   // CHECK: return %arg0 : i16
   return %12 : i16
 }
+
+// CHECK-LABEL: func @index_cast_fold
+func @index_cast_fold() -> (i16, index) {
+  %c4 = constant 4 : index
+  %1 = index_cast %c4 : index to i16
+  %c4_i16 = constant 4 : i16
+  %2 = index_cast %c4_i16 : i16 to index
+  // CHECK: return %c4_i16, %c4 : i16, index
+  return %1, %2 : i16, index
+}