diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2070,6 +2070,8 @@
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2181,6 +2181,19 @@
   return success();
 }
 
+OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
+  // trunci(zexti(a)) -> a, but only when the truncation restores exactly the
+  // type of `a`. Otherwise (e.g. i1 -> i8 -> i2) `a` has a different type than
+  // this op's result and returning it would produce an ill-typed replacement.
+  if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>())) {
+    Value source = getOperand().getDefiningOp()->getOperand(0);
+    if (source.getType() == getType())
+      return source;
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // UnsignedDivIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1059,3 +1059,53 @@
   return %2 : tensor
 }
 
+// -----
+
+// CHECK-LABEL: func @fold_trunci
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: i1)
+func @fold_trunci(%arg0: i1) -> i1 attributes {} {
+  // CHECK-NEXT: return %[[ARG0]] : i1
+  %0 = zexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i1
+  return %1 : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_trunci_vector
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: vector<4xi1>)
+func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi1> attributes {} {
+  // CHECK-NEXT: return %[[ARG0]] : vector<4xi1>
+  %0 = zexti %arg0 : vector<4xi1> to vector<4xi8>
+  %1 = trunci %0 : vector<4xi8> to vector<4xi1>
+  return %1 : vector<4xi1>
+}
+
+// -----
+
+// TODO Canonicalize this into:
+// zexti %arg0 : i1 to i2
+
+// CHECK-LABEL: func @do_not_fold_trunci
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: i1)
+func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} {
+  // CHECK-NEXT: zexti %[[ARG0]] : i1 to i8
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = trunci %{{.*}} : i8 to i2
+  // CHECK-NEXT: return %[[RES]] : i2
+  %0 = zexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i2
+  return %1 : i2
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_fold_trunci_vector
+// CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: vector<4xi1>)
+func @do_not_fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} {
+  // CHECK-NEXT: zexti %[[ARG0]] : vector<4xi1> to vector<4xi8>
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = trunci %{{.*}} : vector<4xi8> to vector<4xi2>
+  // CHECK-NEXT: return %[[RES]] : vector<4xi2>
+  %0 = zexti %arg0 : vector<4xi1> to vector<4xi8>
+  %1 = trunci %0 : vector<4xi8> to vector<4xi2>
+  return %1 : vector<4xi2>
+}