diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -24,14 +24,16 @@ namespace mlir { /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. +/// Uses `resultType` for the type of the returned attribute. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, + Type resultType, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); - if (!operands[0] || !operands[1]) + if (!resultType || !operands[0] || !operands[1]) return {}; if (operands[0].isa() && operands[1].isa()) { @@ -45,7 +47,7 @@ if (!calRes) return {}; - return AttrElementT::get(lhs.getType(), *calRes); + return AttrElementT::get(resultType, *calRes); } if (operands[0].isa() && @@ -62,9 +64,10 @@ if (!elementResult) return {}; - return DenseElementsAttr::get(lhs.getType(), *elementResult); - } else if (operands[0].isa() && - operands[1].isa()) { + return DenseElementsAttr::get(resultType, *elementResult); + } + + if (operands[0].isa() && operands[1].isa()) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. auto lhs = operands[0].cast(); @@ -83,11 +86,53 @@ elementResults.push_back(*elementResult); } - return DenseElementsAttr::get(lhs.getType(), elementResults); + return DenseElementsAttr::get(resultType, elementResults); } return {}; } +/// Performs constant folding `calculate` with element-wise behavior on the two +/// attributes in `operands` and returns the result if possible. +/// Uses the operand element type for the element type of the returned +/// attribute. 
+template (ElementValueT, ElementValueT)>> +Attribute constFoldBinaryOpConditional(ArrayRef operands, + const CalculationT &calculate) { + assert(operands.size() == 2 && "binary op takes two operands"); + auto getResultType = [](Attribute attr) -> Type { + if (auto typed = attr.dyn_cast_or_null()) + return typed.getType(); + return {}; + }; + + Type lhsType = getResultType(operands[0]); + Type rhsType = getResultType(operands[1]); + if (!lhsType || !rhsType) + return {}; + if (lhsType != rhsType) + return {}; + + return constFoldBinaryOpConditional(operands, lhsType, + calculate); +} + +template > +Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, + const CalculationT &calculate) { + return constFoldBinaryOpConditional( + operands, resultType, + [&](ElementValueT a, ElementValueT b) -> std::optional { + return calculate(a, b); + }); +} + template ()) + return RankedTensorType::get(tensorType.getShape(), i1Type); + if (type.isa()) + return UnrankedTensorType::get(i1Type); + if (auto vectorType = type.dyn_cast()) + return VectorType::get(vectorType.getShape(), i1Type, + vectorType.getNumScalableDims()); + return i1Type; +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -276,41 +296,16 @@ // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry // Let the `constFoldBinaryOp` utility attempt to fold the sum of both // operands. If that succeeds, calculate the overflow bit based on the sum - // and the first (constant) operand, `lhs`. Note that we cannot simply call - // `constFoldBinaryOp` again to calculate the overflow bit because the - // constructed attribute is of the same element type as both operands. + // and the first (constant) operand, `lhs`. 
if (Attribute sumAttr = constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; })) { - Attribute overflowAttr; - if (auto lhs = adaptor.getLhs().dyn_cast()) { - // Both arguments are scalars, calculate the scalar overflow value. - auto sum = sumAttr.cast(); - overflowAttr = IntegerAttr::get( - overflowTy, - calculateUnsignedOverflow(sum.getValue(), lhs.getValue())); - } else if (auto lhs = adaptor.getLhs().dyn_cast()) { - // Both arguments are splats, calculate the splat overflow value. - auto sum = sumAttr.cast(); - APInt overflow = calculateUnsignedOverflow(sum.getSplatValue(), - lhs.getSplatValue()); - overflowAttr = SplatElementsAttr::get(overflowTy, overflow); - } else if (auto lhs = adaptor.getLhs().dyn_cast()) { - // Othwerwise calculate element-wise overflow values. - auto sum = sumAttr.cast(); - const auto numElems = static_cast(sum.getNumElements()); - SmallVector overflowValues; - overflowValues.reserve(numElems); - - auto sumIt = sum.value_begin(); - auto lhsIt = lhs.value_begin(); - for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt) - overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt)); - - overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues); - } else { + Attribute overflowAttr = constFoldBinaryOp( + ArrayRef({sumAttr, adaptor.getLhs()}), + getI1SameShape(sumAttr.cast().getType()), + calculateUnsignedOverflow); + if (!overflowAttr) return failure(); - } results.push_back(sumAttr); results.push_back(overflowAttr); @@ -1534,23 +1529,6 @@ patterns.add(context); } -//===----------------------------------------------------------------------===// -// Helpers for compare ops -//===----------------------------------------------------------------------===// - -/// Return the type of the same shape (scalar, vector or tensor) containing i1. 
-static Type getI1SameShape(Type type) { - auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i1Type); - if (type.isa()) - return UnrankedTensorType::get(i1Type); - if (auto vectorType = type.dyn_cast()) - return VectorType::get(vectorType.getShape(), i1Type, - vectorType.getNumScalableDims()); - return i1Type; -} - //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -1671,16 +1649,18 @@ llvm_unreachable("unknown cmpi predicate kind"); } - auto lhs = adaptor.getLhs().dyn_cast_or_null(); - if (!lhs) - return {}; - // We are moving constants to the right side; So if lhs is constant rhs is // guaranteed to be a constant. - auto rhs = adaptor.getRhs().cast(); + if (auto lhs = adaptor.getLhs().dyn_cast_or_null()) { + return constFoldBinaryOp( + adaptor.getOperands(), getI1SameShape(lhs.getType()), + [pred = getPredicate()](const APInt &lhs, const APInt &rhs) { + return APInt(1, + static_cast(applyCmpPredicate(pred, lhs, rhs))); + }); + } - auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return BoolAttr::get(getContext(), val); + return {}; } void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -322,6 +322,46 @@ return %res : i1 } +// CHECK-LABEL: @cmpIFoldEQ +// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1> +// CHECK: return %[[res]] +func.func @cmpIFoldEQ() -> vector<3xi1> { + %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32> + %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32> + %res = arith.cmpi eq, %lhs, %rhs : vector<3xi32> + return %res : vector<3xi1> +} + 
+// CHECK-LABEL: @cmpIFoldNE +// CHECK: %[[res:.+]] = arith.constant dense<[false, false, true]> : vector<3xi1> +// CHECK: return %[[res]] +func.func @cmpIFoldNE() -> vector<3xi1> { + %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32> + %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32> + %res = arith.cmpi ne, %lhs, %rhs : vector<3xi32> + return %res : vector<3xi1> +} + +// CHECK-LABEL: @cmpIFoldSGE +// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1> +// CHECK: return %[[res]] +func.func @cmpIFoldSGE() -> vector<3xi1> { + %lhs = arith.constant dense<2> : vector<3xi32> + %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32> + %res = arith.cmpi sge, %lhs, %rhs : vector<3xi32> + return %res : vector<3xi1> +} + +// CHECK-LABEL: @cmpIFoldULT +// CHECK: %[[res:.+]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[res]] +func.func @cmpIFoldULT() -> vector<3xi1> { + %lhs = arith.constant dense<2> : vector<3xi32> + %rhs = arith.constant dense<1> : vector<3xi32> + %res = arith.cmpi ult, %lhs, %rhs : vector<3xi32> + return %res : vector<3xi1> +} + // ----- // CHECK-LABEL: @andOfExtSI diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1070,13 +1070,13 @@ // CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]] // CHECK-LABEL: @while_loop_invariant_argument_different_order -func.func @while_loop_invariant_argument_different_order() -> (tensor, tensor, tensor, tensor, tensor, tensor) { +func.func @while_loop_invariant_argument_different_order(%arg : tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) { %cst_0 = arith.constant dense<0> : tensor %cst_1 = arith.constant dense<1> : tensor %cst_42 = arith.constant dense<42> : tensor %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor, tensor, tensor, tensor, 
tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) { - %1 = arith.cmpi slt, %arg0, %cst_42 : tensor + %1 = arith.cmpi slt, %arg0, %arg : tensor %2 = tensor.extract %1[] : tensor scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor, tensor } do { @@ -1087,11 +1087,11 @@ } return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor, tensor, tensor, tensor, tensor, tensor } +// CHECK-SAME: (%[[ARG:.+]]: tensor) // CHECK: %[[ZERO:.*]] = arith.constant dense<0> // CHECK: %[[ONE:.*]] = arith.constant dense<1> -// CHECK: %[[CST42:.*]] = arith.constant dense<42> // CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]]) -// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]] +// CHECK: arith.cmpi sgt, %[[ARG]], %[[ZERO]] // CHECK: tensor.extract %{{.*}}[] // CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]] // CHECK: } do {