diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3584,42 +3584,101 @@
   }
 };
 
-/// Inlined complex division
+/// Replace the complex division \p op with a call to the runtime routine
+/// \p funcName, first declaring that routine at the start of the enclosing
+/// module if no LLVM function of that name exists there yet.
+///
+/// \p returnType is the LLVM result type of the call, \p argTypes the types
+/// of its four scalar parameters, and \p args the operands passed in order
+/// (element 0 and 1 of the dividend, then of the divisor).
+///
+/// NOTE(review): the call returns \p returnType (the converted complex type)
+/// directly. Some C ABIs return small complex values differently (e.g.
+/// x86-64 returns `float _Complex` packed in one SSE register), so confirm
+/// that these calls match the runtime library's ABI for every kind - TODO.
+static mlir::LogicalResult getDivc3(fir::DivcOp op,
+                                    mlir::ConversionPatternRewriter &rewriter,
+                                    llvm::StringRef funcName,
+                                    mlir::Type returnType,
+                                    llvm::ArrayRef<mlir::Type> argTypes,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  auto divideFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName);
+  if (!divideFunc) {
+    // Use a separate builder so the rewriter's insertion point stays at op.
+    mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+    divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+        rewriter.getUnknownLoc(), funcName,
+        mlir::LLVM::LLVMFunctionType::get(returnType, argTypes,
+                                          /*isVarArg=*/false));
+  }
+  auto call = rewriter.create<mlir::LLVM::CallOp>(
+      op.getLoc(), returnType, mlir::SymbolRefAttr::get(divideFunc), args);
+  rewriter.replaceOp(op, call->getResults());
+  return mlir::success();
+}
+
+/// Lower fir.divc (complex division). Kinds 2, 4, 8, 10 and 16 become calls
+/// to the __divhc3/__divsc3/__divdc3/__divxc3/__divtc3 runtime helpers.
+/// Kind 3 (bfloat), which has no such helper, is expanded inline with the
+/// direct formula; it does no scaling, so intermediates such as
+/// x'x' + y'y' can overflow or underflow for large or small operands.
 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
   matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    // TODO: Can we use a call to __divdc3 instead?
-    // Just generate inline code for now.
-    // given: (x + iy) / (x' + iy')
-    // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
     mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
     mlir::Type ty = convertType(divc.getType());
     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
-    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
-    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
-    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
-    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
-    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
-    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
-    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
-    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
-    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
-    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
-    auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
-    auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
-    auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
-    rewriter.replaceOp(divc, r0.getResult());
-    return mlir::success();
+
+    // cast<> asserts (in assertion builds) when the result is not a
+    // fir::ComplexType, instead of calling getFKind() on a null handle.
+    fir::KindTy kind = divc.getType().cast<fir::ComplexType>().getFKind();
+    llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
+    llvm::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
+    switch (kind) {
+    case 2:
+      return getDivc3(divc, rewriter, "__divhc3", ty, argTy, args);
+    case 4:
+      return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args);
+    case 8:
+      return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args);
+    case 10:
+      return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args);
+    case 16:
+      return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args);
+    case 3: {
+      // No library function for bfloat - use inline version.
+      // given: (x + iy) / (x' + iy')
+      // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
+      auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+      auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+      auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+      auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+      auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+      auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+      auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+      auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+      auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+      auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+      auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+      auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
+      auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
+      auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
+      rewriter.replaceOp(divc, r0.getResult());
+      return mlir::success();
+    }
+    default:
+      llvm_unreachable("Unsupported complex type");
+    }
   }
 };
 