Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp
===================================================================
--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1241,11 +1241,30 @@
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary operation takes one operand");
 
-  // trunci(zexti(a)) -> a
-  // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
-      matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
-    return getOperand().getDefiningOp()->getOperand(0);
+      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+    // trunci(zexti(a)) -> trunci(a)
+    // trunci(sexti(a)) -> trunci(a)
+    if (srcType.cast<IntegerType>().getWidth() >
+        dstType.cast<IntegerType>().getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // trunci(zexti(a)) -> a
+    // trunci(sexti(a)) -> a
+    if (srcType == dstType)
+      return src;
+
+    // Otherwise `a` is narrower than the result, e.g.
+    //   trunci(extui(a : i8 -> i64) : i64 -> i16).
+    // Returning `a` would substitute a value of the wrong type, and the
+    // equivalent rewrite (a narrower extension of `a`) needs a new op, which
+    // a folder cannot create, so leave the op unchanged.
+  }
 
   // trunci(trunci(a)) -> trunci(a))
   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
Index: mlir/test/Dialect/Arith/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Arith/canonicalize.mlir
+++ mlir/test/Dialect/Arith/canonicalize.mlir
@@ -429,6 +429,37 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtui2
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtui2(%arg0: i32) -> i16 {
+  %extui = arith.extui %arg0 : i32 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtuiVector
+// CHECK: %[[ARG0:.+]]: vector<2xi32>
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : vector<2xi32> to vector<2xi16>
+// CHECK: return %[[RES]]
+func.func @truncExtuiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+  %extui = arith.extui %arg0 : vector<2xi32> to vector<2xi64>
+  %trunci = arith.trunci %extui : vector<2xi64> to vector<2xi16>
+  return %trunci : vector<2xi16>
+}
+
+// The extension source is narrower than the truncation result, so `%arg0`
+// must not be returned directly (it has the wrong type).
+// CHECK-LABEL: @truncExtuiNarrowSrc
+// CHECK: %[[RES:.+]] = arith.{{.*}} : i{{.*}} to i16
+// CHECK: return %[[RES]] : i16
+func.func @truncExtuiNarrowSrc(%arg0: i8) -> i16 {
+  %extui = arith.extui %arg0 : i8 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncExtsi
 // CHECK-NOT: trunci
 // CHECK: return %arg0
@@ -438,6 +469,37 @@
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtsi2
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtsi2(%arg0: i32) -> i16 {
+  %extsi = arith.extsi %arg0 : i32 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtsiVector
+// CHECK: %[[ARG0:.+]]: vector<2xi32>
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : vector<2xi32> to vector<2xi16>
+// CHECK: return %[[RES]]
+func.func @truncExtsiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+  %extsi = arith.extsi %arg0 : vector<2xi32> to vector<2xi64>
+  %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16>
+  return %trunci : vector<2xi16>
+}
+
+// The extension source is narrower than the truncation result, so `%arg0`
+// must not be returned directly (it has the wrong type).
+// CHECK-LABEL: @truncExtsiNarrowSrc
+// CHECK: %[[RES:.+]] = arith.{{.*}} : i{{.*}} to i16
+// CHECK: return %[[RES]] : i16
+func.func @truncExtsiNarrowSrc(%arg0: i8) -> i16 {
+  %extsi = arith.extsi %arg0 : i8 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncConstantSplat
 // CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
 // CHECK: return %[[cres]]