diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1444,6 +1444,14 @@
   return DenseElementsAttr::get(shapedType, boolAttr);
 }
 
+/// Returns the bit width of `t` if it is an integer type or a shaped type
+/// (e.g. vector or tensor) with integer elements, and None otherwise.
+static Optional<unsigned> getIntegerWidth(Type t) {
+  if (auto intType = getElementTypeOrSelf(t).dyn_cast<IntegerType>())
+    return intType.getWidth();
+  return llvm::None;
+}
+
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two operands");
 
@@ -1456,13 +1464,17 @@
   if (matchPattern(getRhs(), m_Zero())) {
     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
       // extsi(%x : i1 -> iN) != 0 -> %x
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+      Optional<unsigned> integerWidth =
+          getIntegerWidth(extOp.getOperand().getType());
+      if (integerWidth && *integerWidth == 1 &&
           getPredicate() == arith::CmpIPredicate::ne)
         return extOp.getOperand();
     }
     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
       // extui(%x : i1 -> iN) != 0 -> %x
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+      Optional<unsigned> integerWidth =
+          getIntegerWidth(extOp.getOperand().getType());
+      if (integerWidth && *integerWidth == 1 &&
           getPredicate() == arith::CmpIPredicate::ne)
         return extOp.getOperand();
     }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -162,7 +162,7 @@
 
 // -----
 
-// CHECK-LABEL: @cmpOfExtSI
+// CHECK-LABEL: @cmpOfExtSI(
 // CHECK-NEXT: return %arg0
 func.func @cmpOfExtSI(%arg0: i1) -> i1 {
   %ext = arith.extsi %arg0 : i1 to i64
@@ -171,7 +171,7 @@
   return %res : i1
 }
 
-// CHECK-LABEL: @cmpOfExtUI
+// CHECK-LABEL: @cmpOfExtUI(
 // CHECK-NEXT: return %arg0
 func.func @cmpOfExtUI(%arg0: i1) -> i1 {
   %ext = arith.extui %arg0 : i1 to i64
@@ -182,6 +182,44 @@
 
 // -----
 
+// CHECK-LABEL: @cmpOfExtSIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+  %c0 = arith.constant dense<0> : vector<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %ext = arith.extui %arg0 : vector<4xi1> to vector<4xi64>
+  %c0 = arith.constant dense<0> : vector<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtSITensor(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtSITensor(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+  %ext = arith.extsi %arg0 : tensor<4xi1> to tensor<4xi64>
+  %c0 = arith.constant dense<0> : tensor<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : tensor<4xi64>
+  return %res : tensor<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUITensor(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtUITensor(%arg0: tensor<4xi1>) -> tensor<4xi1> {
+  %ext = arith.extui %arg0 : tensor<4xi1> to tensor<4xi64>
+  %c0 = arith.constant dense<0> : tensor<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : tensor<4xi64>
+  return %res : tensor<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @extSIOfExtUI
 // CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
 // CHECK: return %[[res]]