diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -23,11 +23,15 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::math;
@@ -279,6 +283,78 @@
 }
 } // namespace
 
+//----------------------------------------------------------------------------//
+// Helper function/pattern to insert casts for reusing F32 bit expansion.
+//----------------------------------------------------------------------------//
+
+/// Rewrites `op` into `arith.truncf(T(arith.extf(operands...)))`: every
+/// operand is widened to the F32 equivalent of its type, the op is recreated
+/// as a `T` computing in F32 (so the existing F32 expansion can match it), and
+/// the result is truncated back to the original type.
+///
+/// Only matches ops with exactly one result whose operands all share the
+/// result's type, and whose element type is a float narrower than 32 bits.
+template <typename T>
+LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
+  // The op is recreated with a single F32 result that replaces the original
+  // result, so multi-result (or result-less) ops cannot be handled.
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "expected a single result");
+
+  // Conservatively only allow ops whose operands all have the result type, so
+  // that one F32 equivalent type can be used for every operand and the result.
+  Type origType = op->getResult(0).getType();
+  for (Type t : op->getOperandTypes())
+    if (origType != t)
+      return rewriter.notifyMatchFailure(op, "required all types to match");
+
+  // Only float element types can be widened with arith.extf; checking this
+  // first also keeps integer/index element types out of the width query.
+  auto elementType = getElementTypeOrSelf(origType).dyn_cast<FloatType>();
+  if (!elementType)
+    return rewriter.notifyMatchFailure(op, "expected float element type");
+
+  // Skip if already F32 or larger than 32 bits.
+  if (elementType.isF32() || elementType.getWidth() > 32)
+    return failure();
+
+  // Create the F32 equivalent type: shaped types keep their shape and only
+  // swap the element type; scalar float types become plain F32.
+  Type newType = rewriter.getF32Type();
+  if (auto shaped = origType.dyn_cast<ShapedType>())
+    newType = shaped.clone(newType);
+
+  Location loc = op->getLoc();
+  SmallVector<Value> operands;
+  for (auto operand : op->getOperands())
+    operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
+  // Forward the original attributes so they are not silently dropped.
+  auto result = rewriter.create<T>(loc, newType, operands, op->getAttrs());
+  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
+  return success();
+}
+
+namespace {
+// Pattern that casts a lower-precision op `T` to F32 so that the existing F32
+// expansion can be reused as a fallback; see insertCasts for the exact
+// conditions under which the rewrite applies.
+// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
+// all in lower precision. Currently this is only fallback support and performs
+// simplistic casting.
+template <typename T>
+struct ReuseF32Expansion : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
+    // insertCasts relies on all operand and result types being identical.
+    static_assert(
+        T::template hasTrait<OpTrait::SameOperandsAndResultType>(),
+        "requires same operands and result types");
+    return insertCasts<T>(op, rewriter);
+  }
+};
+} // namespace
+
 //----------------------------------------------------------------------------//
 // AtanOp approximation.
 //----------------------------------------------------------------------------//
@@ -1212,1 +1288,2 @@
+               ReuseF32Expansion<math::Atan2Op>,
                SinAndCosApproximation<true, math::SinOp>,
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -542,7 +542,9 @@
 // CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
 // CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
 // CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1
+// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
+// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
 // CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]]
 // CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
 // CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
@@ -562,30 +564,31 @@
 // CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
 // CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
 // CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
-// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
 // CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
 
 // Handle PI / 2 edge case:
-// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
-// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
+// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
 // CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
 
 // Handle -PI / 2 edge case:
 // CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
-// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
 // CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
 
 // Handle Nan edgecase:
-// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
 // CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
 // CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
-// CHECK: return %[[EDGE3]]
+// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
+// CHECK: return %[[RET]]
 
-func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
-  %0 = math.atan2 %arg0, %arg1 : f32
-  return %0 : f32
+func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
+  %0 = math.atan2 %arg0, %arg1 : f16
+  return %0 : f16
 }
 