Index: flang/lib/Optimizer/CodeGen/CodeGen.cpp =================================================================== --- flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -37,6 +37,10 @@ static constexpr unsigned kAttrPointer = CFI_attribute_pointer; static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable; +static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { + return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8)); +} + static mlir::LLVM::ConstantOp genConstantIndex(mlir::Location loc, mlir::Type ity, mlir::ConversionPatternRewriter &rewriter, @@ -64,6 +68,9 @@ mlir::Type convertType(mlir::Type ty) const { return lowerTy().convertType(ty); } + mlir::Type voidPtrTy() const { + return getVoidPtrType(&lowerTy().getContext()); + } mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, @@ -246,7 +253,7 @@ using FIROpConversion::FIROpConversion; mlir::LogicalResult - matchAndRewrite(fir::AbsentOp absent, OpAdaptor, + matchAndRewrite(fir::AbsentOp absent, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type ty = convertType(absent.getType()); mlir::Location loc = absent.getLoc(); @@ -267,7 +274,21 @@ return success(); } }; +} // namespace + +/// Helper function for generating the LLVM IR that computes the size +/// in bytes for a derived type. +static mlir::Value +computeDerivedTypeSize(mlir::Location loc, mlir::Type ptrTy, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter) { + auto nullPtr = rewriter.create(loc, ptrTy); + mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + llvm::SmallVector args{nullPtr, one}; + auto gep = rewriter.create(loc, ptrTy, args); + return rewriter.create(loc, idxTy, gep); +} +namespace { // Lower `fir.address_of` operation to `llvm.address_of` operation. struct AddrOfOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -581,6 +602,7 @@ return success(); } }; +} // namespace static mlir::Type getComplexEleTy(mlir::Type complex) { if (auto cc = complex.dyn_cast()) @@ -588,6 +610,7 @@ return complex.cast().getElementType(); } +namespace { /// Compare complex values /// /// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une). @@ -838,6 +861,103 @@ gentypedesc, "fir.fir.gentypedesc codegen is not implemented yet"); } }; +} // namespace + +/// Return the LLVMFuncOp corresponding to the standard malloc call. +static mlir::LLVM::LLVMFuncOp +getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) { + auto module = op->getParentOfType(); + if (auto mallocFunc = module.lookupSymbol("malloc")) + return mallocFunc; + mlir::OpBuilder moduleBuilder( + op->getParentOfType().getBodyRegion()); + auto indexType = mlir::IntegerType::get(op.getContext(), 64); + return moduleBuilder.create( + rewriter.getUnknownLoc(), "malloc", + mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()), + indexType, + /*isVarArg=*/false)); +} + +namespace { +/// convert to `call` to the runtime to `malloc` memory +struct AllocMemOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type ty = convertType(heap.getType()); + mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter); + auto loc = heap.getLoc(); + auto ity = lowerTy().indexType(); + if (auto recTy = + fir::unwrapSequenceType(heap.getType()).dyn_cast()) + if (recTy.getNumLenParams() != 0) + return rewriter.notifyMatchFailure( + loc, "fir.allocmem of derived type with length parameters"); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, ty); + for (mlir::Value opnd : adaptor.getOperands()) + size = rewriter.create( + loc, ity, size, integerCast(loc, rewriter, ity, opnd)); + heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + auto malloc = rewriter.create( + loc, ::getVoidPtrType(heap.getContext()), size, heap->getAttrs()); + rewriter.replaceOpWithNewOp(heap, ty, + malloc.getResult(0)); + return success(); + } + + // Compute the (allocation) size of the allocmem type in bytes. + mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy) const { + // Use the primitive size, if available. + auto ptrTy = llTy.dyn_cast(); + if (auto size = + mlir::LLVM::getPrimitiveTypeSizeInBits(ptrTy.getElementType())) + return genConstantIndex(loc, idxTy, rewriter, size / 8); + + // Otherwise, generate the GEP trick in LLVM IR to compute the size. + return computeDerivedTypeSize(loc, ptrTy, idxTy, rewriter); + } +}; +} // namespace + +/// obtain the free() function +static mlir::LLVM::LLVMFuncOp +getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) { + auto module = op->getParentOfType(); + if (auto freeFunc = module.lookupSymbol("free")) + return freeFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext()); + return moduleBuilder.create( + rewriter.getUnknownLoc(), "free", + mlir::LLVM::LLVMFunctionType::get(voidType, + getVoidPtrType(op.getContext()), + /*isVarArg=*/false)); +} + +namespace { +/// lower a freemem instruction into a call to free() +struct FreeMemOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter); + auto loc = freemem.getLoc(); + auto bitcast = rewriter.create( + freemem.getLoc(), voidPtrTy(), adaptor.getOperands()[0]); + freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); + rewriter.create( + loc, mlir::TypeRange{}, mlir::ValueRange{bitcast}, freemem->getAttrs()); + rewriter.eraseOp(freemem); + return success(); + } +}; /// Convert `fir.end` struct FirEndOpConversion : public FIROpConversion { @@ -945,11 +1065,12 @@ return mlir::LLVM::Linkage::External; } }; +} // namespace -void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, - Optional destOps, - mlir::ConversionPatternRewriter &rewriter, - mlir::Block *newBlock) { +static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter, + mlir::Block *newBlock) { if (destOps.hasValue()) rewriter.create(loc, cmp, dest, destOps.getValue(), newBlock, mlir::ValueRange()); @@ -958,8 +1079,8 @@ } template -void genBrOp(A caseOp, mlir::Block *dest, Optional destOps, - mlir::ConversionPatternRewriter &rewriter) { +static void genBrOp(A caseOp, mlir::Block *dest, Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { if (destOps.hasValue()) rewriter.replaceOpWithNewOp(caseOp, destOps.getValue(), dest); @@ -967,9 +1088,10 @@ rewriter.replaceOpWithNewOp(caseOp, llvm::None, dest); } -void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, - Optional destOps, - mlir::ConversionPatternRewriter &rewriter) { +static void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, + mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { auto *thisBlock = rewriter.getInsertionBlock(); auto *newBlock = createBlock(rewriter, dest); rewriter.setInsertionPointToEnd(thisBlock); @@ -977,6 +1099,7 @@ rewriter.setInsertionPointToEnd(newBlock); } +namespace { /// Conversion of `fir.select_case` /// /// The `fir.select_case` operation is converted to a if-then-else ladder. @@ -1061,11 +1184,12 @@ return success(); } }; +} // namespace template -void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, - typename OP::Adaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) { +static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, + typename OP::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) { unsigned conds = select.getNumConditions(); auto cases = select.getCases().getValue(); mlir::Value selector = adaptor.selector(); @@ -1110,6 +1234,7 @@ /*branchWeights=*/ArrayRef()); } +namespace { /// conversion of fir::SelectOp to an if-then-else ladder struct SelectOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -1256,6 +1381,7 @@ return success(); } }; +} // namespace /// Common base class for embox to descriptor conversion. template @@ -1582,7 +1708,6 @@ } }; - // Code shared between insert_value and extract_value Ops. struct ValueOpCommon { // Translate the arguments pertaining to any multidimensional array to @@ -1644,6 +1769,7 @@ } }; +namespace { /// Extract a subobject value from an ssa-value of aggregate type struct ExtractValueOpConversion : public FIROpAndTypeConversion, @@ -1758,6 +1884,7 @@ return success(); } }; +} // namespace // // Primitive operations on Complex types @@ -1765,9 +1892,10 @@ /// Generate inline code for complex addition/subtraction template -mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds, - mlir::ConversionPatternRewriter &rewriter, - fir::LLVMTypeConverter &lowering) { +static mlir::LLVM::InsertValueOp +complexSum(OPTY sumop, mlir::ValueRange opnds, + mlir::ConversionPatternRewriter &rewriter, + fir::LLVMTypeConverter &lowering) { mlir::Value a = opnds[0]; mlir::Value b = opnds[1]; auto loc = sumop.getLoc(); @@ -1787,6 +1915,7 @@ return rewriter.create(loc, ty, r1, ry, c1); } +namespace { struct AddcOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -2026,10 +2155,11 @@ return success(); } }; +} // namespace /// Construct an `llvm.extractvalue` instruction. It will return value at /// element \p x from \p tuple. -mlir::LLVM::ExtractValueOp +static mlir::LLVM::ExtractValueOp genExtractValueWithIndex(mlir::Location loc, mlir::Value tuple, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter, mlir::MLIRContext *ctx, int x) { @@ -2038,6 +2168,7 @@ return rewriter.create(loc, xty, tuple, cx); } +namespace { /// Convert `!fir.boxchar_len` to `!llvm.extractvalue` for the 2nd part of the /// boxchar. struct BoxCharLenOpConversion : public FIROpConversion { @@ -2169,24 +2300,25 @@ mlir::OwningRewritePatternList pattern(context); pattern.insert< AbsentOpConversion, AddcOpConversion, AddrOfOpConversion, - AllocaOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion, - BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion, - BoxIsArrayOpConversion, BoxIsPtrOpConversion, BoxProcHostOpConversion, - BoxRankOpConversion, BoxTypeDescOpConversion, CallOpConversion, - CmpcOpConversion, ConstcOpConversion, ConvertOpConversion, - DispatchOpConversion, DispatchTableOpConversion, DTEntryOpConversion, - DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion, - EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion, - FirEndOpConversion, HasValueOpConversion, GenTypeDescOpConversion, - GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, - InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion, - NegcOpConversion, NoReassocOpConversion, MulcOpConversion, - SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, - SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, - ShiftOpConversion, SliceOpConversion, StoreOpConversion, - StringLitOpConversion, SubcOpConversion, UnboxCharOpConversion, - UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion, - ZeroOpConversion>(typeConverter); + AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion, + BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion, + BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion, + BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeDescOpConversion, + CallOpConversion, CmpcOpConversion, ConstcOpConversion, + ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion, + DTEntryOpConversion, DivcOpConversion, EmboxOpConversion, + EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion, + FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion, + HasValueOpConversion, GenTypeDescOpConversion, GlobalLenOpConversion, + GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion, + IsPresentOpConversion, LoadOpConversion, NegcOpConversion, + NoReassocOpConversion, MulcOpConversion, SelectCaseOpConversion, + SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, + ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, + SliceOpConversion, StoreOpConversion, StringLitOpConversion, + SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion, + UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>( + typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); Index: flang/test/Fir/convert-to-llvm.fir =================================================================== --- flang/test/Fir/convert-to-llvm.fir +++ flang/test/Fir/convert-to-llvm.fir @@ -165,16 +165,56 @@ // ----- -// Verify that fir.unreachable is transformed to llvm.unreachable +// Verify that fir.allocmem is transformed to a call to malloc +// and that fir.freemem is transformed to a call to free +// Single item case -// CHECK: llvm.func @test_unreachable() { -// CHECK-NEXT: llvm.unreachable -// CHECK-NEXT: } +func @test_alloc_and_freemem_one() { + %z0 = fir.allocmem i32 + fir.freemem %z0 : !fir.heap + return +} + +// CHECK-LABEL: llvm.func @test_alloc_and_freemem_one() { +// CHECK-NEXT: [[N:%.*]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK-NEXT: llvm.call @malloc([[N]]) +// CHECK: llvm.call @free(%{{.*}}) +// CHECK-NEXT: llvm.return + +// ----- +// Verify that fir.allocmem is transformed to a call to malloc +// and that fir.freemem is transformed to a call to free +// Several item case + +func @test_alloc_and_freemem_several() { + %z0 = fir.allocmem !fir.array<100xf32> + fir.freemem %z0 : !fir.heap> + return +} + +// CHECK-LABEL: llvm.func @test_alloc_and_freemem_several() { +// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr> +// CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: [[PTR:%.*]] = llvm.getelementptr [[NULL]][{{.*}}] : (!llvm.ptr>, i64) -> !llvm.ptr> +// CHECK: [[N:%.*]] = llvm.ptrtoint [[PTR]] : !llvm.ptr> to i64 +// CHECK: [[MALLOC:%.*]] = llvm.call @malloc([[N]]) +// CHECK: [[B1:%.*]] = llvm.bitcast [[MALLOC]] : !llvm.ptr to !llvm.ptr> +// CHECK: [[B2:%.*]] = llvm.bitcast [[B1]] : !llvm.ptr> to !llvm.ptr +// CHECK: llvm.call @free([[B2]]) +// CHECK: llvm.return + +// ----- + +// Verify that fir.unreachable is transformed to llvm.unreachable func @test_unreachable() { fir.unreachable } +// CHECK: llvm.func @test_unreachable() { +// CHECK-NEXT: llvm.unreachable +// CHECK-NEXT: } + // ----- // Test `fir.select` operation conversion pattern.