Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp
===================================================================
--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1245,13 +1245,30 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary operation takes one operand");
-
-  // trunci(zexti(a)) -> a
-  // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
-      matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
-    return getOperand().getDefiningOp()->getOperand(0);
+      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+    // `src` is the value being extended, i.e. `a` in the patterns below.
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    int64_t srcBitWidth = getScalarOrElementWidth(src.getType());
+    int64_t dstBitWidth = getScalarOrElementWidth(getType());
+
+    // trunci(zexti(a)) -> trunci(a)
+    // trunci(sexti(a)) -> trunci(a)
+    // The truncation discards every bit the extension added, so truncate `a`
+    // directly. This is an in-place fold: update the operand, return self.
+    if (srcBitWidth > dstBitWidth) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // trunci(zexti(a)) -> a
+    // trunci(sexti(a)) -> a
+    if (srcBitWidth == dstBitWidth)
+      return src;
+
+    // srcBitWidth < dstBitWidth: the result is a narrower extension of `a`,
+    // which requires creating a new op, so it cannot be expressed as a fold.
+  }
 
   // trunci(trunci(a)) -> trunci(a))
   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
Index: mlir/test/Dialect/Arith/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Arith/canonicalize.mlir
+++ mlir/test/Dialect/Arith/canonicalize.mlir
@@ -429,6 +429,16 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtuiToTrunci
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtuiToTrunci(%arg0: i32) -> i16 {
+  %extui = arith.extui %arg0 : i32 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncExtsi
 // CHECK-NOT: trunci
 // CHECK: return %arg0
@@ -438,6 +448,16 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtsiToTrunci
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtsiToTrunci(%arg0: i32) -> i16 {
+  %extsi = arith.extsi %arg0 : i32 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncConstantSplat
 // CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
 // CHECK: return %[[cres]]