NOTE(review): This patch text is corrupted and cannot be applied as-is.
Regenerate it from the source tree. Observed damage:
  * Line breaks were collapsed, so each file section runs together on a few
    physical lines instead of one diff line per source line.
  * Angle-bracketed spans that begin with a letter appear to have been
    removed, which dropped C++ template arguments. The following all lost
    their <...> arguments:
      `ArrayRef operands`, `isa()`, `cast()`, `dyn_cast()`,
      `getSplatValue()`, `value_begin()`, `SmallVector elementResults`,
      `-> Optional {`, `function_ref(ElementValueT, ElementValueT)>>`.
    Spans such as `template <`, `dense<[...]>` and `vector<4xf32>` survived.
  * Between the CommonFolders.h hunk's trailing context `template ` and the
    fragment `operands) {`, three things are missing:
      the next file's diff/file headers,
      its hunk header,
      the signature of the function being rewritten.
    That hunk's trailing context is `OpFoldResult math::SqrtOp::fold`, so the
    file is presumably the Math dialect ops source. Judging by the removed
    pow/powf code and the new test, the rewritten function is presumably the
    fold hook for math.powf. Confirm before regenerating.

Summary of what the visible patch does:
  * CommonFolders.h gains constFoldBinaryOpConditional. It is a two-operand
    folder whose `calculate` callback returns an optional-like value; an
    empty result aborts the fold. It returns an empty Attribute if either
    operand is null or the two operand types differ. Otherwise:
      * Both scalar attributes (presumably AttrElementT): fold the two
        values and wrap the result with AttrElementT::get.
      * Both splats: call `calculate` once on the splat values and build a
        DenseElementsAttr.
      * Both ElementsAttr-derived: fold element by element. Any failed
        element aborts the whole fold.
      * Anything else: no fold.
  * The powf fold is rewritten on top of that helper:
      * Operands with 64-bit float semantics fold via pow() on doubles.
      * Operands with 32-bit float semantics fold via powf() on floats.
      * Any other width is not folded.
    The removed version only accepted scalar attribute operands, and it keyed
    on the width of the op's getType() rather than the operands' semantics.
  * canonicalize.mlir gains @powf_fold_vec. It checks that math.powf of
    [1,2,3,4] and [2,2,2,2] on vector<4xf32> folds to a constant
    dense<[1,4,9,16]>.

Review notes on the patch content:
  * The `} else if (...)` that follows the splat branch's `return` could be a
    plain `if`, matching the helper's earlier branches.
  * In the powf callback, the width check on `b` looks redundant. The helper
    already rejects operands whose types differ, assuming equal attribute
    types imply equal float semantics. Confirm.

diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -70,6 +70,66 @@ return {}; } +template < + class AttrElementT, class ElementValueT = typename AttrElementT::ValueType, + class CalculationT = + function_ref(ElementValueT, ElementValueT)>> +Attribute constFoldBinaryOpConditional(ArrayRef operands, + const CalculationT &calculate) { + assert(operands.size() == 2 && "binary op takes two operands"); + if (!operands[0] || !operands[1]) + return {}; + if (operands[0].getType() != operands[1].getType()) + return {}; + + if (operands[0].isa() && operands[1].isa()) { + auto lhs = operands[0].cast(); + auto rhs = operands[1].cast(); + + auto calRes = calculate(lhs.getValue(), rhs.getValue()); + + if (!calRes) + return {}; + + return AttrElementT::get(lhs.getType(), *calRes); + } + + if (operands[0].isa() && + operands[1].isa()) { + // Both operands are splats so we can avoid expanding the values out and + // just fold based on the splat value. + auto lhs = operands[0].cast(); + auto rhs = operands[1].cast(); + + auto elementResult = calculate(lhs.getSplatValue(), + rhs.getSplatValue()); + if (!elementResult) + return {}; + + return DenseElementsAttr::get(lhs.getType(), *elementResult); + } else if (operands[0].isa() && + operands[1].isa()) { + // Operands are ElementsAttr-derived; perform an element-wise fold by + // expanding the values.
+ auto lhs = operands[0].cast(); + auto rhs = operands[1].cast(); + + auto lhsIt = lhs.value_begin(); + auto rhsIt = rhs.value_begin(); + SmallVector elementResults; + elementResults.reserve(lhs.getNumElements()); + for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) { + auto elementResult = calculate(*lhsIt, *rhsIt); + if (!elementResult) + return {}; + elementResults.push_back(*elementResult); + } + + return DenseElementsAttr::get(lhs.getType(), elementResults); + } + return {}; +} + /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. template operands) { - auto ft = getType().dyn_cast(); - if (!ft) - return {}; - - APFloat vals[2]{APFloat(ft.getFloatSemantics()), - APFloat(ft.getFloatSemantics())}; - for (int i = 0; i < 2; ++i) { - if (!operands[i]) - return {}; - - auto attr = operands[i].dyn_cast(); - if (!attr) - return {}; - - vals[i] = attr.getValue(); - } - - if (ft.getWidth() == 64) - return FloatAttr::get( - getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble())); - - if (ft.getWidth() == 32) - return FloatAttr::get( - getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat())); - - return {}; + return constFoldBinaryOpConditional( + operands, [](const APFloat &a, const APFloat &b) -> Optional { + if (a.getSizeInBits(a.getSemantics()) == 64 && + b.getSizeInBits(b.getSemantics()) == 64) + return APFloat(pow(a.convertToDouble(), b.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32 && + b.getSizeInBits(b.getSemantics()) == 32) + return APFloat(powf(a.convertToFloat(), b.convertToFloat())); + + return {}; + }); } OpFoldResult math::SqrtOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -83,6 +83,16 @@ return %r : f32 } +// 
CHECK-LABEL: @powf_fold_vec +// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 4.000000e+00, 9.000000e+00, 1.600000e+01]> : vector<4xf32> +// CHECK: return %[[cst]] +func.func @powf_fold_vec() -> (vector<4xf32>) { + %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32> + %0 = math.powf %v1, %v2 : vector<4xf32> + return %0 : vector<4xf32> +} + // CHECK-LABEL: @sqrt_fold // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32 // CHECK: return %[[cst]]