Index: flang/lib/Optimizer/CodeGen/CodeGen.cpp
===================================================================
--- flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "PassDetail.h"
 #include "flang/ISO_Fortran_binding.h"
+#include "flang/Lower/Todo.h" // remove when the TODOs are done
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
@@ -33,6 +34,10 @@
 static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
 static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
 
+static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
+  return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
+}
+
 static mlir::LLVM::ConstantOp
 genConstantIndex(mlir::Location loc, mlir::Type ity,
                  mlir::ConversionPatternRewriter &rewriter,
@@ -60,6 +65,9 @@
   mlir::Type convertType(mlir::Type ty) const {
     return lowerTy().convertType(ty);
   }
+  mlir::Type voidPtrTy() const {
+    return getVoidPtrType(&lowerTy().getContext());
+  }
 
   mlir::LLVM::ConstantOp
   genConstantOffset(mlir::Location loc,
@@ -234,6 +242,18 @@
   }
 };
 
+/// Helper function for generating the LLVM IR that computes the size in bytes
+/// of the element type of `ptrTy` as `ptrtoint(getelementptr(null, 1))`.
+static mlir::Value
+computeDerivedTypeSize(mlir::Location loc, mlir::Type ptrTy, mlir::Type idxTy,
+                       mlir::ConversionPatternRewriter &rewriter) {
+  auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, ptrTy);
+  mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+  llvm::SmallVector<mlir::Value> args{nullPtr, one};
+  auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, ptrTy, args);
+  return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
+}
+
 // Lower `fir.address_of` operation to `llvm.address_of` operation.
 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
   using FIROpConversion::FIROpConversion;
@@ -533,6 +553,7 @@
     return success();
   }
 };
+} // namespace
 
 static mlir::Type getComplexEleTy(mlir::Type complex) {
   if (auto cc = complex.dyn_cast<mlir::ComplexType>())
@@ -540,6 +561,7 @@
   return complex.cast<fir::ComplexType>().getElementType();
 }
 
+namespace {
 /// Compare complex values
 ///
 /// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
@@ -758,6 +780,102 @@
         gentypedesc, "fir.fir.gentypedesc codegen is not implemented yet");
   }
 };
+} // namespace
+
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::LLVM::LLVMFuncOp
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("malloc"))
+    return mallocFunc;
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto indexType = mlir::IntegerType::get(op.getContext(), 64);
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      rewriter.getUnknownLoc(), "malloc",
+      mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()),
+                                        indexType,
+                                        /*isVarArg=*/false));
+}
+
+namespace {
+/// convert to `call` to the runtime to `malloc` memory
+struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Type ty = convertType(heap.getType());
+    mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter);
+    mlir::Location loc = heap.getLoc();
+    auto ity = lowerTy().indexType();
+    if (auto recTy =
+            fir::unwrapSequenceType(heap.getType()).dyn_cast<fir::RecordType>())
+      if (recTy.getNumLenParams() != 0)
+        TODO(loc, "fir.allocmem of derived type with length parameters");
+    mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, ty);
+    // Multiply the base size by each of the op's operands.
+    for (mlir::Value opnd : adaptor.getOperands())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size, integerCast(loc, rewriter, ity, opnd));
+    heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
+    auto malloc = rewriter.create<mlir::LLVM::CallOp>(
+        loc, ::getVoidPtrType(heap.getContext()), size, heap->getAttrs());
+    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(heap, ty,
+                                                       malloc.getResult(0));
+    return success();
+  }
+
+  // Compute the (allocation) size of the allocmem type in bytes.
+  mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type llTy) const {
+    // Use the primitive size if available, rounding sub-byte types (i1) up.
+    auto ptrTy = llTy.cast<mlir::LLVM::LLVMPointerType>();
+    if (auto size =
+            mlir::LLVM::getPrimitiveTypeSizeInBits(ptrTy.getElementType()))
+      return genConstantIndex(loc, idxTy, rewriter, (size + 7) / 8);
+
+    // Otherwise, generate the GEP trick in LLVM IR to compute the size.
+    return computeDerivedTypeSize(loc, ptrTy, idxTy, rewriter);
+  }
+};
+} // namespace
+
+/// obtain the free() function
+static mlir::LLVM::LLVMFuncOp
+getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("free"))
+    return freeFunc;
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      rewriter.getUnknownLoc(), "free",
+      mlir::LLVM::LLVMFunctionType::get(voidType,
+                                        getVoidPtrType(op.getContext()),
+                                        /*isVarArg=*/false));
+}
+
+namespace {
+/// lower a freemem instruction into a call to free()
+struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter);
+    mlir::Location loc = freemem.getLoc();
+    auto bitcast = rewriter.create<mlir::LLVM::BitcastOp>(
+        loc, voidPtrTy(), adaptor.getOperands()[0]);
+    freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
+    rewriter.create<mlir::LLVM::CallOp>(
+        loc, mlir::TypeRange{}, mlir::ValueRange{bitcast}, freemem->getAttrs());
+    rewriter.eraseOp(freemem);
+    return success();
+  }
+};
 
 /// Lower `fir.has_value` operation to `llvm.return` operation.
 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
@@ -853,11 +971,12 @@
     return mlir::LLVM::Linkage::External;
   }
 };
+} // namespace
 
-void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
-                 Optional<mlir::ValueRange> destOps,
-                 mlir::ConversionPatternRewriter &rewriter,
-                 mlir::Block *newBlock) {
+static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
+                        Optional<mlir::ValueRange> destOps,
+                        mlir::ConversionPatternRewriter &rewriter,
+                        mlir::Block *newBlock) {
   if (destOps.hasValue())
     rewriter.create<mlir::CondBranchOp>(loc, cmp, dest, destOps.getValue(),
                                         newBlock, mlir::ValueRange());
@@ -866,8 +985,8 @@
 }
 
 template <typename A, typename B>
-void genBrOp(A caseOp, mlir::Block *dest, Optional<B> destOps,
-             mlir::ConversionPatternRewriter &rewriter) {
+static void genBrOp(A caseOp, mlir::Block *dest, Optional<B> destOps,
+                    mlir::ConversionPatternRewriter &rewriter) {
   if (destOps.hasValue())
     rewriter.replaceOpWithNewOp<mlir::BranchOp>(caseOp, destOps.getValue(),
                                                 dest);
@@ -875,9 +994,10 @@
     rewriter.replaceOpWithNewOp<mlir::BranchOp>(caseOp, llvm::None, dest);
 }
 
-void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
-                       Optional<mlir::ValueRange> destOps,
-                       mlir::ConversionPatternRewriter &rewriter) {
+static void genCaseLadderStep(mlir::Location loc, mlir::Value cmp,
+                              mlir::Block *dest,
+                              Optional<mlir::ValueRange> destOps,
+                              mlir::ConversionPatternRewriter &rewriter) {
   auto *thisBlock = rewriter.getInsertionBlock();
   auto *newBlock = createBlock(rewriter, dest);
   rewriter.setInsertionPointToEnd(thisBlock);
@@ -885,6 +1005,7 @@
   rewriter.setInsertionPointToEnd(newBlock);
 }
 
+namespace {
 /// Conversion of `fir.select_case`
 ///
 /// The `fir.select_case` operation is converted to a if-then-else ladder.
@@ -969,11 +1090,12 @@
     return success();
   }
 };
+} // namespace
 
 template <typename OP>
-void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
-                           typename OP::Adaptor adaptor,
-                           mlir::ConversionPatternRewriter &rewriter) {
+static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
+                                  typename OP::Adaptor adaptor,
+                                  mlir::ConversionPatternRewriter &rewriter) {
   unsigned conds = select.getNumConditions();
   auto cases = select.getCases().getValue();
   mlir::Value selector = adaptor.selector();
@@ -1018,6 +1140,7 @@
       /*branchWeights=*/ArrayRef<int32_t>());
 }
 
+namespace {
 /// conversion of fir::SelectOp to an if-then-else ladder
 struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
   using FIROpConversion::FIROpConversion;
@@ -1326,6 +1449,7 @@
     return success();
   }
 };
+} // namespace
 
 //
 // Primitive operations on Complex types
@@ -1333,9 +1457,10 @@
 
 /// Generate inline code for complex addition/subtraction
 template <typename OPTY>
-mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds,
-                                     mlir::ConversionPatternRewriter &rewriter,
-                                     fir::LLVMTypeConverter &lowering) {
+static mlir::LLVM::InsertValueOp
+complexSum(OPTY sumop, mlir::ValueRange opnds,
+           mlir::ConversionPatternRewriter &rewriter,
+           fir::LLVMTypeConverter &lowering) {
   mlir::Value a = opnds[0];
   mlir::Value b = opnds[1];
   auto loc = sumop.getLoc();
@@ -1355,6 +1480,7 @@
   return rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, r1, ry, c1);
 }
 
+namespace {
 struct AddcOpConversion : public FIROpConversion<fir::AddcOp> {
   using FIROpConversion::FIROpConversion;
 
@@ -1594,10 +1720,11 @@
     return success();
   }
 };
+} // namespace
 
 /// Construct an `llvm.extractvalue` instruction. It will return value at
 /// element \p x from \p tuple.
-mlir::LLVM::ExtractValueOp
+static mlir::LLVM::ExtractValueOp
 genExtractValueWithIndex(mlir::Location loc, mlir::Value tuple, mlir::Type ty,
                          mlir::ConversionPatternRewriter &rewriter,
                          mlir::MLIRContext *ctx, int x) {
@@ -1606,6 +1733,7 @@
   return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, xty, tuple, cx);
 }
 
+namespace {
 /// Convert `!fir.boxchar_len` to `!llvm.extractvalue` for the 2nd part of the
 /// boxchar.
 struct BoxCharLenOpConversion : public FIROpConversion<fir::BoxCharLenOp> {
@@ -1681,21 +1809,22 @@
     mlir::OwningRewritePatternList pattern(context);
     pattern.insert<
         AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
-        AllocaOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion,
-        BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion,
-        BoxIsArrayOpConversion, BoxIsPtrOpConversion, BoxRankOpConversion,
-        BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion,
-        ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
-        DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
-        ExtractValueOpConversion, HasValueOpConversion, GenTypeDescOpConversion,
-        GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
-        InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
-        NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
-        SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
-        ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
-        SliceOpConversion, StoreOpConversion, StringLitOpConversion,
-        SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
-        UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+        AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
+        BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
+        BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
+        BoxRankOpConversion, BoxTypeDescOpConversion, CallOpConversion,
+        CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
+        DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
+        EmboxCharOpConversion, ExtractValueOpConversion, FreeMemOpConversion,
+        HasValueOpConversion, GenTypeDescOpConversion, GlobalLenOpConversion,
+        GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
+        IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
+        MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
+        SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
+        ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
+        StoreOpConversion, StringLitOpConversion, SubcOpConversion,
+        UnboxCharOpConversion, UndefOpConversion, UnreachableOpConversion,
+        ZeroOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
Index: flang/test/Fir/convert-to-llvm.fir
===================================================================
--- flang/test/Fir/convert-to-llvm.fir
+++ flang/test/Fir/convert-to-llvm.fir
@@ -165,16 +165,56 @@
 
 // -----
 
-// Verify that fir.unreachable is transformed to llvm.unreachable
+// Verify that fir.allocmem is transformed to a call to malloc
+// and that fir.freemem is transformed to a call to free
+// Single-item case
 
-// CHECK: llvm.func @test_unreachable() {
-// CHECK-NEXT: llvm.unreachable
-// CHECK-NEXT: }
+func @test_alloc_and_freemem_one() {
+  %z0 = fir.allocmem i32
+  fir.freemem %z0 : !fir.heap<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @test_alloc_and_freemem_one() {
+// CHECK-NEXT: [[N:%.*]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK-NEXT: llvm.call @malloc([[N]])
+// CHECK: llvm.call @free(%{{.*}})
+// CHECK-NEXT: llvm.return
+
+// -----
+// Verify that fir.allocmem is transformed to a call to malloc
+// and that fir.freemem is transformed to a call to free
+// Several-item (array) case
+
+func @test_alloc_and_freemem_several() {
+  %z0 = fir.allocmem !fir.array<100xf32>
+  fir.freemem %z0 : !fir.heap<!fir.array<100xf32>>
+  return
+}
+
+// CHECK-LABEL: llvm.func @test_alloc_and_freemem_several() {
+// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<array<100 x f32>>
+// CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: [[PTR:%.*]] = llvm.getelementptr [[NULL]][{{.*}}] : (!llvm.ptr<array<100 x f32>>, i64) -> !llvm.ptr<array<100 x f32>>
+// CHECK: [[N:%.*]] = llvm.ptrtoint [[PTR]] : !llvm.ptr<array<100 x f32>> to i64
+// CHECK: [[MALLOC:%.*]] = llvm.call @malloc([[N]])
+// CHECK: [[B1:%.*]] = llvm.bitcast [[MALLOC]] : !llvm.ptr<i8> to !llvm.ptr<array<100 x f32>>
+// CHECK: [[B2:%.*]] = llvm.bitcast [[B1]] : !llvm.ptr<array<100 x f32>> to !llvm.ptr<i8>
+// CHECK: llvm.call @free([[B2]])
+// CHECK: llvm.return
+
+// -----
+
+// Verify that fir.unreachable is transformed to llvm.unreachable
 
 func @test_unreachable() {
   fir.unreachable
 }
 
+// CHECK: llvm.func @test_unreachable() {
+// CHECK-NEXT: llvm.unreachable
+// CHECK-NEXT: }
+
 // -----
 
 // Test `fir.select` operation conversion pattern.