diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -16,6 +16,8 @@
 void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
+/// Expands `math.fma` into `arith.mulf` followed by `arith.addf` (unfused).
+void populateExpandFmaFPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -66,6 +66,24 @@
   return success();
 }
 
+/// Expands `math.fma %a, %b, %c` into `arith.mulf %a, %b` followed by
+/// `arith.addf %mul, %c`, both with the result type of the original op.
+/// The expansion is not fused: the product is rounded before the addition,
+/// so results can differ from a single-rounding fused multiply-add.
+/// NOTE(review): attributes of the original op (e.g. any fastmath flags) are
+/// not forwarded to the new ops; confirm this is intended.
+static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Value operandC = op.getOperand(2);
+  Type type = op.getType();
+  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
+  Value add = b.create<arith::AddFOp>(type, mult, operandC);
+  rewriter.replaceOp(op, add);
+  return success();
+}
+
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
                                    PatternRewriter &rewriter) {
   auto operand = op.getOperand();
@@ -126,3 +144,7 @@
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }
+
+void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFmaFOp);
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -53,3 +53,15 @@
   // CHECK: return %[[WHILE]]#1
   return %res : i32
 }
+
+// -----
+
+// CHECK-LABEL: func @fmaf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64, [[ARG2:%.+]]: f64) -> f64
+func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
+  // CHECK-NEXT: [[MULF:%.+]] = arith.mulf [[ARG0]], [[ARG1]] : f64
+  // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[MULF]], [[ARG2]] : f64
+  // CHECK-NEXT: return [[ADDF]] : f64
+  %ret = math.fma %a, %b, %c : f64
+  return %ret : f64
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -37,6 +37,7 @@
   populateExpandCtlzPattern(patterns);
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
+  populateExpandFmaFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 