diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1056,13 +1056,30 @@
   llvm_unreachable("unknown cmpi predicate kind");
 }
 
+/// Returns a boolean constant attribute holding `value` whose form matches
+/// `type`: a `BoolAttr` when `type` is null or not a shaped type, and a
+/// `DenseElementsAttr` with every element equal to `value` when `type` is a
+/// statically shaped vector/tensor. Returns a null attribute for shaped types
+/// without a static shape, which a dense elements attribute cannot describe;
+/// returning that null from `fold` leaves the op unfolded.
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  auto boolAttr = BoolAttr::get(ctx, value);
+  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
+  if (!shapedType)
+    return boolAttr;
+  // DenseElementsAttr requires a static shape (e.g. not tensor<?xi1>).
+  if (!shapedType.hasStaticShape())
+    return {};
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two operands");
 
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return BoolAttr::get(getContext(), val);
+    // Vector/tensor comparisons need a shaped result attribute, not a scalar.
+    return getBoolAttribute(getType(), getContext(), val);
   }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -4,10 +4,11 @@
 // CHECK-LABEL: @cmpi_equal_operands
 // CHECK-DAG: %[[T:.*]] = arith.constant true
 // CHECK-DAG: %[[F:.*]] = arith.constant false
+// CHECK-DAG: %[[F2:.*]] = arith.constant dense<false> : vector<1x8xi1>
 // CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
-// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
-func @cmpi_equal_operands(%arg0: i64)
-    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F2]]
+func @cmpi_equal_operands(%arg0: i64, %arg1: vector<1x8xi64>)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, vector<1x8xi1>) {
   %0 = arith.cmpi eq, %arg0, %arg0 : i64
   %1 = arith.cmpi sle, %arg0, %arg0 : i64
   %2 = arith.cmpi sge, %arg0, %arg0 : i64
@@ -17,9 +18,21 @@
   %6 = arith.cmpi slt, %arg0, %arg0 : i64
   %7 = arith.cmpi sgt, %arg0, %arg0 : i64
   %8 = arith.cmpi ult, %arg0, %arg0 : i64
-  %9 = arith.cmpi ugt, %arg0, %arg0 : i64
+  %9 = arith.cmpi ugt, %arg1, %arg1 : vector<1x8xi64>
   return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
-      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, vector<1x8xi1>
 }
 
 // -----
+
+// Equal-operand cmpi on a dynamically shaped tensor cannot be folded to a
+// dense elements constant (that requires a static shape), so it must remain.
+// CHECK-LABEL: @cmpi_equal_operands_dynamic_tensor
+// CHECK: %[[R:.*]] = arith.cmpi eq
+// CHECK: return %[[R]]
+func @cmpi_equal_operands_dynamic_tensor(%arg0: tensor<?xi64>) -> tensor<?xi1> {
+  %0 = arith.cmpi eq, %arg0, %arg0 : tensor<?xi64>
+  return %0 : tensor<?xi1>
+}
+
+// -----