diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <string>
 
 namespace fir {
 #define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3512,42 +3513,89 @@
   }
 };
 
-/// Inlined complex division
+/// Replace \p op with a call to the complex-division runtime routine
+/// \p funcName (e.g. `__divdc3`), taking \p args of types \p argType and
+/// returning \p returnType. The routine is declared in the enclosing module
+/// the first time it is needed; an existing declaration is reused afterwards.
+static mlir::LogicalResult getDivc3(fir::DivcOp op,
+                                    mlir::ConversionPatternRewriter &rewriter,
+                                    std::string funcName, mlir::Type returnType,
+                                    llvm::SmallVector<mlir::Type> argType,
+                                    llvm::SmallVector<mlir::Value> args) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  auto loc = op.getLoc();
+  auto divideFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName);
+  if (!divideFunc) {
+    // First use in this module: emit the declaration into the module body.
+    mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+    divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+        rewriter.getUnknownLoc(), funcName,
+        mlir::LLVM::LLVMFunctionType::get(returnType, argType,
+                                          /*isVarArg=*/false));
+  }
+  auto call = rewriter.create<mlir::LLVM::CallOp>(
+      loc, returnType, mlir::SymbolRefAttr::get(divideFunc), args);
+  rewriter.replaceOp(op, call->getResults());
+  return mlir::success();
+}
+
+/// Complex division: calls the compiler-rt `__div?c3` routine for KIND 4, 8,
+/// 10 and 16, and expands inline for KIND 2 and 3.
 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
   matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    // TODO: Can we use a call to __divdc3 instead?
-    // Just generate inline code for now.
     // given: (x + iy) / (x' + iy')
     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
     mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
-    mlir::Type ty = convertType(divc.getType());
+    llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
+    mlir::Type firReturnTy = divc.getType();
+    mlir::Type ty = convertType(firReturnTy);
     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
-    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
-    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
-    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
-    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
-    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
-    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
-    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
-    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
-    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
-    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
-    auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
-    auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
-    auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
-    rewriter.replaceOp(divc, r0.getResult());
-    return mlir::success();
+
+    // cast<> asserts on a non-complex type instead of using a null type.
+    fir::KindTy kind = firReturnTy.cast<fir::ComplexType>().getFKind();
+    mlir::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
+    switch (kind) {
+    default:
+      llvm_unreachable("Unsupported complex type");
+    case 4:
+      return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args);
+    case 8:
+      return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args);
+    case 10:
+      return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args);
+    case 16:
+      return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args);
+    case 3:
+    case 2:
+      // No library function for bfloat or half in compiler_rt, generate
+      // inline instead
+      auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+      auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+      auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+      auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+      auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+      auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+      auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+      auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+      auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+      auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+      auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+      auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
+      auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
+      auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
+      rewriter.replaceOp(divc, r0.getResult());
+      return mlir::success();
+    }
   }
 };
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -586,22 +586,42 @@
 // CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
-// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
-// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
-// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
-// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
-// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
-// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
+// CHECK: %[[CALL:.*]] = llvm.call @__divtc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f128, f128, f128, f128) -> !llvm.struct<(f128, f128)>
 // CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
 
+// -----
+
+// Test FIR complex division inlines for KIND=3
+
+func.func @fir_complex_div(%a: !fir.complex<3>, %b: !fir.complex<3>) -> !fir.complex<3> {
+  %c = fir.divc %a, %b : !fir.complex<3>
+  return %c : !fir.complex<3>
+}
+
+// CHECK-LABEL: llvm.func @fir_complex_div(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(bf16, bf16)>,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(bf16, bf16)>) -> !llvm.struct<(bf16, bf16)> {
+// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : bf16
+// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : bf16
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : bf16
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : bf16
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : bf16
+// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : bf16
+// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : bf16
+// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : bf16
+// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : bf16
+// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
+// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
+// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(bf16, bf16)>
+// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(bf16, bf16)>
+// CHECK: llvm.return %{{.*}} : !llvm.struct<(bf16, bf16)>
+
+
 // -----
 
 // Test FIR complex negation conversion