// Review notes for this patch. This text precedes the first "diff --git"
// header and belongs to no hunk; patch tools such as `git apply` and `patch`
// skip leading text before the first header.
//
// NOTE(review): this copy of the patch appears to have been extracted lossily.
// - Spans that look like HTML tags are missing from the C++ parts. That is,
//   `<` directly followed by a letter, up to the next `>`. Examples:
//   `template >`, `ArrayRef operands`, `operands[0].isa()`,
//   `function_ref(ElementValueT)>>`.
// - The original line breaks are collapsed.
// - MLIR text such as `vector<4xf32>` survived.
// As shown it can be neither applied nor compiled. Regenerate it from version
// control rather than hand-repairing it.
//
// What the patch does, as far as the visible text shows:
//
// * mlir/include/mlir/Dialect/CommonFolders.h
//   - The unary element-wise folder is renamed to
//     constFoldUnaryOpConditional. Its `calculate` callback now yields an
//     Optional.
//   - Three paths each return a null Attribute as soon as `calculate` produces
//     an empty result: the scalar path, the splat path, and the per-element
//     ElementsAttr loop. Either every element folds or nothing does; there is
//     no partial fold.
//   - constFoldUnaryOp is re-added as a thin wrapper. It wraps `calculate` in
//     a lambda returning Optional and forwards to constFoldUnaryOpConditional,
//     so existing callers presumably keep their old behavior (their results
//     are always engaged).
//   - The lambda captures `calculate` by reference. That is fine because it is
//     only invoked during the forwarded call.
//   - The existing doc comment ("... returns the result if possible") now sits
//     on constFoldUnaryOpConditional and does not mention the empty-Optional
//     contract. The new constFoldUnaryOp wrapper has no doc comment at all.
//     TODO: document both in the header.
//
// * mlir/lib/Dialect/Math/IR/MathOps.cpp
//   - math::Log2Op::fold now forwards to constFoldUnaryOpConditional. Its
//     template argument is lost in this copy. Previously it unpacked a single
//     attribute operand. Splat and element-wise constants such as
//     vector<4xf32> can now fold as well.
//   - Per element, the callback declines to fold (returns an empty Optional)
//     in two cases: the value is negative, or its float semantics are neither
//     64 nor 32 bits wide.
//   - 64-bit values use log2(); 32-bit values use log2f().
//   - The width check moved from the op's type (`ft.getWidth()`) to each
//     APFloat's semantics. This is presumably needed because the op type may
//     now be a vector rather than a scalar float type.
//   - NOTE(review): APFloat::isNegative() is also true for -0.0, so
//     log2(-0.0) (= -inf) is still not folded. The old code behaved the same.
//
// * mlir/test/Dialect/Math/canonicalize.mlir
//   - New @log2_fold_vec: log2 of dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
//     is expected to canonicalize to one arith.constant holding
//     [0.0, 1.0, 1.58496249, 2.0].
//
// * mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
//   - For log2 of dense<[0.25, 0.75, 1.0, 1.25]>, the expected second lane
//     changes from -0.415037 to -0.415038.
//   - NOTE(review): the new value matches a correctly rounded f32 log2f(0.75)
//     rather than the previous approximation output. This suggests the
//     constant vector is now folded before the approximation is applied.
//   - If so, this runner test no longer exercises the vector log2
//     approximation for these inputs. Consider passing the values in through
//     a function argument so they cannot be constant-folded.
//     TODO: confirm by inspecting the lowered IR.
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -98,11 +98,11 @@ /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. -template > -Attribute constFoldUnaryOp(ArrayRef operands, - const CalculationT &&calculate) { +template < + class AttrElementT, class ElementValueT = typename AttrElementT::ValueType, + class CalculationT = function_ref(ElementValueT)>> +Attribute constFoldUnaryOpConditional(ArrayRef operands, + const CalculationT &&calculate) { assert(operands.size() == 1 && "unary op takes one operands"); if (!operands[0]) return {}; @@ -110,7 +110,10 @@ if (operands[0].isa()) { auto op = operands[0].cast(); - return AttrElementT::get(op.getType(), calculate(op.getValue())); + auto res = calculate(op.getValue()); + if (!res) + return {}; + return AttrElementT::get(op.getType(), *res); } if (operands[0].isa()) { // Both operands are splats so we can avoid expanding the values out and @@ -118,7 +121,9 @@ auto op = operands[0].cast(); auto elementResult = calculate(op.getSplatValue()); - return DenseElementsAttr::get(op.getType(), elementResult); + if (!elementResult) + return {}; + return DenseElementsAttr::get(op.getType(), *elementResult); } else if (operands[0].isa()) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. 
@@ -127,13 +132,27 @@ auto opIt = op.value_begin(); SmallVector elementResults; elementResults.reserve(op.getNumElements()); - for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) - elementResults.push_back(calculate(*opIt)); + for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { + auto elementResult = calculate(*opIt); + if (!elementResult) + return {}; + elementResults.push_back(*elementResult); + } return DenseElementsAttr::get(op.getType(), elementResults); } return {}; } +template > +Attribute constFoldUnaryOp(ArrayRef operands, + const CalculationT &&calculate) { + return constFoldUnaryOpConditional( + operands, + [&](ElementValueT a) -> Optional { return calculate(a); }); +} + template < class AttrElementT, class TargetAttrElementT, class ElementValueT = typename AttrElementT::ValueType, diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -92,28 +92,19 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Log2Op::fold(ArrayRef operands) { - auto constOperand = operands.front(); - if (!constOperand) - return {}; - - auto attr = constOperand.dyn_cast(); - if (!attr) - return {}; + return constFoldUnaryOpConditional( + operands, [](const APFloat &a) -> Optional { + if (a.isNegative()) + return {}; - auto ft = getType().cast(); + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(log2(a.convertToDouble())); - APFloat apf = attr.getValue(); + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(log2f(a.convertToFloat())); - if (apf.isNegative()) - return {}; - - if (ft.getWidth() == 64) - return FloatAttr::get(getType(), log2(apf.convertToDouble())); - - if (ft.getWidth() == 32) - return FloatAttr::get(getType(), log2f(apf.convertToFloat())); - - return {}; + return {}; + }); } 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -74,6 +74,15 @@ return %r : f64 } +// CHECK-LABEL: @log2_fold_vec +// CHECK: %[[cst:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 1.58496249, 2.000000e+00]> : vector<4xf32> +// CHECK: return %[[cst]] +func.func @log2_fold_vec() -> (vector<4xf32>) { + %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %0 = math.log2 %v1 : vector<4xf32> + return %0 : vector<4xf32> +} + // CHECK-LABEL: @powf_fold // CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32 // CHECK: return %[[cst]] diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -80,7 +80,7 @@ %1 = math.log2 %0 : f32 vector.print %1 : f32 - // CHECK: -2, -0.415037, 0, 0.321928 + // CHECK: -2, -0.415038, 0, 0.321928 %2 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32> %3 = math.log2 %2 : vector<4xf32> vector.print %3 : vector<4xf32>