Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp
===================================================================
--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1241,11 +1241,28 @@
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary operation takes one operand");
 
-  // trunci(zexti(a)) -> a
-  // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
-      matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
-    return getOperand().getDefiningOp()->getOperand(0);
+      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+    // Both extension ops take the pre-extension value `a` as operand 0.
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+
+    // trunci(zexti(a)) -> trunci(a)
+    // trunci(sexti(a)) -> trunci(a)
+    if (srcType.cast<IntegerType>().getWidth() >
+        dstType.cast<IntegerType>().getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // trunci(zexti(a)) -> a
+    // trunci(sexti(a)) -> a
+    // If `a` is narrower than the result, returning it would change the
+    // result type, so no fold is performed in that case.
+    if (srcType == dstType)
+      return src;
+  }
 
   // trunci(trunci(a)) -> trunci(a))
   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
Index: mlir/test/Dialect/Arith/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Arith/canonicalize.mlir
+++ mlir/test/Dialect/Arith/canonicalize.mlir
@@ -429,6 +429,24 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtui2
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[TRUNC]] : i16
+func.func @truncExtui2(%arg0: i32) -> i16 {
+  %extui = arith.extui %arg0 : i32 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtuiNarrowSrc
+// CHECK: return %{{.+}} : i32
+func.func @truncExtuiNarrowSrc(%arg0: i8) -> i32 {
+  %extui = arith.extui %arg0 : i8 to i64
+  %trunci = arith.trunci %extui : i64 to i32
+  return %trunci : i32
+}
+
 // CHECK-LABEL: @truncExtsi
 // CHECK-NOT: trunci
 // CHECK: return %arg0
@@ -438,6 +456,24 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtsi2
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[TRUNC]] : i16
+func.func @truncExtsi2(%arg0: i32) -> i16 {
+  %extsi = arith.extsi %arg0 : i32 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtsiNarrowSrc
+// CHECK: return %{{.+}} : i32
+func.func @truncExtsiNarrowSrc(%arg0: i8) -> i32 {
+  %extsi = arith.extsi %arg0 : i8 to i64
+  %trunci = arith.trunci %extsi : i64 to i32
+  return %trunci : i32
+}
+
 // CHECK-LABEL: @truncConstantSplat
 // CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
 // CHECK: return %[[cres]]