Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp
===================================================================
--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1243,15 +1243,35 @@
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary operation takes one operand");
 
-  // trunci(zexti(a)) -> a
-  // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
-      matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
-    return getOperand().getDefiningOp()->getOperand(0);
+      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    unsigned srcWidth =
+        getElementTypeOrSelf(src.getType()).cast<IntegerType>().getWidth();
+    unsigned dstWidth =
+        getElementTypeOrSelf(getType()).cast<IntegerType>().getWidth();
+
+    // trunci(zexti(a)) -> trunci(a)
+    // trunci(sexti(a)) -> trunci(a)
+    // The extension only adds high bits that this truncation drops, so
+    // truncate `a` directly by rewiring this op's operand in place.
+    if (srcWidth > dstWidth) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // trunci(zexti(a)) -> a
+    // trunci(sexti(a)) -> a
+    if (src.getType() == getType())
+      return src;
+
+    // Otherwise `a` is narrower than the result: returning it would change
+    // the result type, so this pattern must not fold.
+  }
 
   // trunci(trunci(a)) -> trunci(a))
   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
Index: mlir/test/Dialect/Arith/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Arith/canonicalize.mlir
+++ mlir/test/Dialect/Arith/canonicalize.mlir
@@ -429,6 +429,27 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtuiToTrunci
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-NOT: arith.extui
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtuiToTrunci(%arg0: i32) -> i16 {
+  %extui = arith.extui %arg0 : i32 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
+// The source is narrower than the truncated result: the folder must not
+// replace the result with %arg0, which has a different type.
+// CHECK-LABEL: @truncExtuiNarrowSrc
+// CHECK: return %{{.+}} : i32
+func.func @truncExtuiNarrowSrc(%arg0: i8) -> i32 {
+  %extui = arith.extui %arg0 : i8 to i64
+  %trunci = arith.trunci %extui : i64 to i32
+  return %trunci : i32
+}
+
 // CHECK-LABEL: @truncExtsi
 // CHECK-NOT: trunci
 // CHECK: return %arg0
@@ -438,6 +459,17 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtsiToTrunci
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-NOT: arith.extsi
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtsiToTrunci(%arg0: i32) -> i16 {
+  %extsi = arith.extsi %arg0 : i32 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncConstantSplat
 // CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
 // CHECK: return %[[cres]]