diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -23,12 +23,12 @@ namespace mlir { /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. -template > -Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { +template < + class AttrElementT, class ElementValueT = typename AttrElementT::ValueType, + class CalculationT = + function_ref(ElementValueT, ElementValueT)>> +Attribute constFoldBinaryOpConditional(ArrayRef operands, + const CalculationT &calculate) { /* Conditional variant of constFoldBinaryOp: `calculate` returns an Optional. If it is empty for the scalar pair, for the splat pair, or for any single element of an element-wise pair, the whole fold is abandoned and a null Attribute is returned (no partially folded constant is built). A null operand, or an operand combination matching none of the branches, also yields a null Attribute. */ assert(operands.size() == 2 && "binary op takes two operands"); if (!operands[0] || !operands[1]) return {}; @@ -39,9 +39,14 @@ auto lhs = operands[0].cast(); auto rhs = operands[1].cast(); - return AttrElementT::get(lhs.getType(), - calculate(lhs.getValue(), rhs.getValue())); + auto calRes = calculate(lhs.getValue(), rhs.getValue()); + + if (!calRes) + return {}; + + return AttrElementT::get(lhs.getType(), *calRes); } + if (operands[0].isa() && operands[1].isa()) { // Both operands are splats so we can avoid expanding the values out and @@ -51,7 +56,10 @@ auto elementResult = calculate(lhs.getSplatValue(), rhs.getSplatValue()); - return DenseElementsAttr::get(lhs.getType(), elementResult); + if (!elementResult) + return {}; + + return DenseElementsAttr::get(lhs.getType(), *elementResult); } else if (operands[0].isa() && operands[1].isa()) { // Operands are ElementsAttr-derived; perform an element-wise fold by @@ -63,13 +71,31 @@ auto rhsIt = rhs.value_begin(); SmallVector elementResults; elementResults.reserve(lhs.getNumElements()); - for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) - elementResults.push_back(calculate(*lhsIt, *rhsIt)); + for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) { + auto 
elementResult = calculate(*lhsIt, *rhsIt); + if (!elementResult) + return {}; + elementResults.push_back(*elementResult); + } + return DenseElementsAttr::get(lhs.getType(), elementResults); } return {}; } +template > +Attribute constFoldBinaryOp(ArrayRef operands, + const CalculationT &calculate) { /* Unconditional form under the previous name: wraps `calculate` so its result is always present and forwards to constFoldBinaryOpConditional, so every element folds. Capturing `calculate` by reference is safe because the lambda is only invoked synchronously inside that call. */ + return constFoldBinaryOpConditional( + operands, + [&](ElementValueT a, ElementValueT b) -> Optional { + return calculate(a, b); + }); +} + /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. template operands) { - auto ft = getType().dyn_cast(); - if (!ft) - return {}; - - APFloat vals[2]{APFloat(ft.getFloatSemantics()), - APFloat(ft.getFloatSemantics())}; - for (int i = 0; i < 2; ++i) { - if (!operands[i]) - return {}; - - auto attr = operands[i].dyn_cast(); - if (!attr) - return {}; - - vals[i] = attr.getValue(); - } - - if (ft.getWidth() == 64) - return FloatAttr::get( - getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble())); - - if (ft.getWidth() == 32) - return FloatAttr::get( - getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat())); - - return {}; + return constFoldBinaryOpConditional( + operands, [](const APFloat &a, const APFloat &b) -> Optional { /* Folds only when both operands are 64-bit (evaluated with pow) or both are 32-bit (evaluated with powf). Any other width combination returns an empty Optional, so no constant is produced. Routing through constFoldBinaryOpConditional lets splat and element-wise vector constants fold too, not only scalar attributes. NOTE(review): the removed code picked the path from the op result type (getType() width); this version inspects operand semantics instead. Confirm that operands always share the result element type. NOTE(review): getSizeInBits is passed the semantics explicitly, so it looks like a static helper; if so, APFloat::getSizeInBits(a.getSemantics()) would read more clearly. Please confirm. */ + if (a.getSizeInBits(a.getSemantics()) == 64 && + b.getSizeInBits(b.getSemantics()) == 64) + return APFloat(pow(a.convertToDouble(), b.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32 && + b.getSizeInBits(b.getSemantics()) == 32) + return APFloat(powf(a.convertToFloat(), b.convertToFloat())); + + return {}; + }); } OpFoldResult math::SqrtOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -83,6 +83,16 @@ return %r : f32 } +// CHECK-LABEL: @powf_fold_vec +// CHECK: %[[cst:.+]] = arith.constant 
dense<[1.000000e+00, 4.000000e+00, 9.000000e+00, 1.600000e+01]> : vector<4xf32> +// CHECK: return %[[cst]] +func.func @powf_fold_vec() -> (vector<4xf32>) { + %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32> + %0 = math.powf %v1, %v2 : vector<4xf32> + return %0 : vector<4xf32> +} + // CHECK-LABEL: @sqrt_fold // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32 // CHECK: return %[[cst]]