diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1056,13 +1056,24 @@
   llvm_unreachable("unknown cmpi predicate kind");
 }
 
+/// Returns a boolean attribute holding `value`. When `type` is a shaped type
+/// (e.g. a vector of i1), the value is wrapped in a DenseElementsAttr of that
+/// type so the folded result has the same type as the op's result.
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  auto boolAttr = BoolAttr::get(ctx, value);
+  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
+  if (!shapedType)
+    return boolAttr;
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two operands");
 
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return BoolAttr::get(getContext(), val);
+    return getBoolAttribute(getType(), getContext(), val);
   }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -22,6 +22,32 @@
       : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }
 
+// Test case: Folding of comparisons with equal vector operands.
+// CHECK-LABEL: @cmpi_equal_vector_operands
+// CHECK-DAG: %[[T:.*]] = arith.constant dense<true> : vector<1x8xi1>
+// CHECK-DAG: %[[F:.*]] = arith.constant dense<false> : vector<1x8xi1>
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
+    -> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>) {
+  %0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64>
+  %1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64>
+  %2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64>
+  %3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64>
+  %4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64>
+  %5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64>
+  %6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64>
+  %7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64>
+  %8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64>
+  %9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64>
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+        vector<1x8xi1>, vector<1x8xi1>
+}
+
 // -----
 
 // CHECK-LABEL: @indexCastOfSignExtend