diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2183,7 +2183,12 @@
 
 OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
   // trunci(zexti(a)) -> a
-  if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>()))
+  // trunci(sexti(a)) -> a
+  // Only valid when `a` already has the result type; e.g.
+  // trunci(sexti(%x : i1 to i8) : i8 to i4) must not be replaced by %x.
+  if ((matchPattern(getOperand(), m_Op<ZeroExtendIOp>()) ||
+       matchPattern(getOperand(), m_Op<SignExtendIOp>())) &&
+      getOperand().getDefiningOp()->getOperand(0).getType() == getType())
     return getOperand().getDefiningOp()->getOperand(0);
 
   return nullptr;
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1109,3 +1109,28 @@
   %1 = trunci %0 : vector<4xi8> to vector<4xi2>
   return %1 : vector<4xi2>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_trunci_sexti
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: i1)
+func @fold_trunci_sexti(%arg0: i1) -> i1 attributes {} {
+  // CHECK-NEXT: return %[[ARG0]] : i1
+  %0 = sexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i1
+  return %1 : i1
+}
+
+// -----
+
+// Truncating to a width other than the source's must not fold to %arg0.
+// CHECK-LABEL: func @do_not_fold_trunci_sexti_width_mismatch
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: i1)
+func @do_not_fold_trunci_sexti_width_mismatch(%arg0: i1) -> i4 {
+  // CHECK-NEXT: %[[SEXT:.*]] = sexti %[[ARG0]] : i1 to i8
+  // CHECK-NEXT: %[[TRUNC:.*]] = trunci %[[SEXT]] : i8 to i4
+  // CHECK-NEXT: return %[[TRUNC]] : i4
+  %0 = sexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i4
+  return %1 : i4
+}