Index: flang/lib/Optimizer/CodeGen/CodeGen.cpp
===================================================================
--- flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "PassDetail.h"
 #include "flang/ISO_Fortran_binding.h"
+#include "flang/Lower/Todo.h" // remove when TODOs are done
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
@@ -33,6 +34,11 @@
 static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
 static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
 
+/// Return the LLVM dialect pointer-to-i8 type, used to model C `void *`.
+static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
+  return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
+}
+
 static mlir::LLVM::ConstantOp
 genConstantIndex(mlir::Location loc, mlir::Type ity,
                  mlir::ConversionPatternRewriter &rewriter,
@@ -60,6 +66,31 @@
   mlir::Type convertType(mlir::Type ty) const {
     return lowerTy().convertType(ty);
   }
+  /// Return the `i8*` (C `void *`) type in this converter's context.
+  mlir::Type voidPtrTy() const {
+    return getVoidPtrType(&lowerTy().getContext());
+  }
+
+  /// Perform an extension or truncation as needed on an integer value. Lowering
+  /// to the specific target may involve some sign-extending or truncation of
+  /// values, particularly to fit them from abstract box types to the
+  /// appropriate reified structures.
+  mlir::Value integerCast(mlir::Location loc,
+                          mlir::ConversionPatternRewriter &rewriter,
+                          mlir::Type ty, mlir::Value val) const {
+    auto valTy = val.getType();
+    // If the value was not yet lowered, lower its type so that it can
+    // be used in getPrimitiveTypeSizeInBits.
+    if (!valTy.isa<mlir::IntegerType>())
+      valTy = convertType(valTy);
+    auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+    auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
+    if (toSize < fromSize)
+      return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+    if (toSize > fromSize)
+      return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+    return val;
+  }
 
   mlir::LLVM::ConstantOp
   genConstantOffset(mlir::Location loc,
@@ -161,27 +192,6 @@
     return rewriter.create<mlir::LLVM::GEPOp>(loc, ty, base, cv);
   }
 
-  /// Perform an extension or truncation as needed on an integer value. Lowering
-  /// to the specific target may involve some sign-extending or truncation of
-  /// values, particularly to fit them from abstract box types to the
-  /// appropriate reified structures.
-  mlir::Value integerCast(mlir::Location loc,
-                          mlir::ConversionPatternRewriter &rewriter,
-                          mlir::Type ty, mlir::Value val) const {
-    auto valTy = val.getType();
-    // If the value was not yet lowered, lower its type so that it can
-    // be used in getPrimitiveTypeSizeInBits.
-    if (!valTy.isa<mlir::IntegerType>())
-      valTy = convertType(valTy);
-    auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
-    auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
-    if (toSize < fromSize)
-      return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
-    if (toSize > fromSize)
-      return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
-    return val;
-  }
-
   fir::LLVMTypeConverter &lowerTy() const {
     return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
   }
@@ -212,7 +222,7 @@
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(fir::AbsentOp absent, OpAdaptor,
+  matchAndRewrite(fir::AbsentOp absent, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type ty = convertType(absent.getType());
     mlir::Location loc = absent.getLoc();
@@ -234,6 +244,20 @@
   }
 };
 
+namespace {
+/// Generate LLVM IR computing the size in bytes of the type pointed to by
+/// `ptrTy` (e.g. an array or derived type) as `ptrtoint(gep(null, 1))`.
+mlir::Value computeDerivedTypeSize(mlir::Location loc, mlir::Type ptrTy,
+                                   mlir::Type idxTy,
+                                   mlir::ConversionPatternRewriter &rewriter) {
+  mlir::Value nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, ptrTy);
+  mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+  llvm::SmallVector<mlir::Value> args{one};
+  auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, ptrTy, nullPtr, args);
+  return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
+}
+} // namespace
+
 // Lower `fir.address_of` operation to `llvm.address_of` operation.
 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
   using FIROpConversion::FIROpConversion;
@@ -759,6 +783,102 @@
   }
 };
 
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::LLVM::LLVMFuncOp
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("malloc"))
+    return mallocFunc;
+  // Not declared yet: add a `malloc` declaration to the enclosing module.
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto indexType = mlir::IntegerType::get(op.getContext(), 64);
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      rewriter.getUnknownLoc(), "malloc",
+      mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()),
+                                        indexType,
+                                        /*isVarArg=*/false));
+}
+
+namespace {
+/// Lower `fir.allocmem` to a call to `malloc` and a bitcast of its result.
+struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto ty = convertType(heap.getType());
+    auto mallocFunc = getMalloc(heap, rewriter);
+    auto loc = heap.getLoc();
+    auto ity = lowerTy().indexType();
+    if (auto recTy =
+            fir::unwrapSequenceType(heap.getType()).dyn_cast<fir::RecordType>())
+      if (recTy.getNumLenParams() != 0)
+        TODO(loc, "fir.allocmem of derived type with length parameters");
+    auto size = genTypeSizeInBytes(loc, ity, rewriter, ty);
+    for (auto opnd : adaptor.getOperands())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size, integerCast(loc, rewriter, ity, opnd));
+    heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
+    auto malloc = rewriter.create<mlir::LLVM::CallOp>(
+        loc, voidPtrTy(), size, heap->getAttrs());
+    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(heap, ty,
+                                                       malloc.getResult(0));
+    return success();
+  }
+
+  // Compute the (allocation) size of the allocmem type in bytes.
+  mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type llTy) const {
+    // Use the primitive size, rounded up to whole bytes, if available.
+    auto ptrTy = llTy.cast<mlir::LLVM::LLVMPointerType>();
+    if (auto size =
+            mlir::LLVM::getPrimitiveTypeSizeInBits(ptrTy.getElementType()))
+      return genConstantIndex(loc, idxTy, rewriter, (size + 7) / 8);
+
+    // Otherwise, generate the GEP trick in LLVM IR to compute the size.
+    return computeDerivedTypeSize(loc, ptrTy, idxTy, rewriter);
+  }
+};
+} // namespace
+
+/// Return the LLVMFuncOp for `free`, declaring it in the module if absent.
+static mlir::LLVM::LLVMFuncOp
+getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("free"))
+    return freeFunc;
+  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
+  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+      rewriter.getUnknownLoc(), "free",
+      mlir::LLVM::LLVMFunctionType::get(voidType,
+                                        getVoidPtrType(op.getContext()),
+                                        /*isVarArg=*/false));
+}
+
+namespace {
+/// Lower `fir.freemem` to a bitcast to `i8*` followed by a call to `free`.
+struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto freeFunc = getFree(freemem, rewriter);
+    auto loc = freemem.getLoc();
+    auto bitcast = rewriter.create<mlir::LLVM::BitcastOp>(
+        loc, voidPtrTy(), adaptor.getOperands()[0]);
+    freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
+    rewriter.create<mlir::LLVM::CallOp>(
+        loc, mlir::TypeRange{}, mlir::ValueRange{bitcast}, freemem->getAttrs());
+    rewriter.eraseOp(freemem);
+    return success();
+  }
+};
+} // namespace
+
 /// Lower `fir.has_value` operation to `llvm.return` operation.
 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
   using FIROpConversion::FIROpConversion;
@@ -1681,21 +1801,22 @@
   mlir::OwningRewritePatternList pattern(context);
   pattern.insert<
       AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
-      AllocaOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion,
-      BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion,
-      BoxIsArrayOpConversion, BoxIsPtrOpConversion, BoxRankOpConversion,
-      BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion,
-      ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
-      DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
-      ExtractValueOpConversion, HasValueOpConversion, GenTypeDescOpConversion,
-      GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
-      InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
-      NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
-      SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
-      ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
-      SliceOpConversion, StoreOpConversion, StringLitOpConversion,
-      SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
-      UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+      AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
+      BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
+      BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
+      BoxRankOpConversion, BoxTypeDescOpConversion, CallOpConversion,
+      CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
+      DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
+      EmboxCharOpConversion, ExtractValueOpConversion, FreeMemOpConversion,
+      HasValueOpConversion, GenTypeDescOpConversion, GlobalLenOpConversion,
+      GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
+      IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
+      MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
+      SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
+      ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
+      StoreOpConversion, StringLitOpConversion, SubcOpConversion,
+      UnboxCharOpConversion, UndefOpConversion, UnreachableOpConversion,
+      ZeroOpConversion>(typeConverter);
   mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
   mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                           pattern);
Index: flang/test/Fir/convert-to-llvm.fir
===================================================================
--- flang/test/Fir/convert-to-llvm.fir
+++ flang/test/Fir/convert-to-llvm.fir
@@ -165,6 +165,46 @@
 
 // -----
 
+// Verify that fir.allocmem is transformed to a call to malloc
+// and that fir.freemem is transformed to a call to free
+// Single item case
+
+// CHECK: llvm.func @test_alloc_and_freemem_one() {
+// CHECK-NEXT: [[N:%.*]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK-NEXT: llvm.call @malloc([[N]])
+// CHECK: llvm.call @free(%{{.*}})
+// CHECK-NEXT: llvm.return
+
+func @test_alloc_and_freemem_one() {
+  %z0 = fir.allocmem i32
+  fir.freemem %z0 : !fir.heap<i32>
+  return
+}
+
+// -----
+// Verify that fir.allocmem is transformed to a call to malloc
+// and that fir.freemem is transformed to a call to free
+// Several item case
+
+// CHECK: llvm.func @test_alloc_and_freemem_several() {
+// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<array<100 x f32>>
+// CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: [[PTR:%.*]] = llvm.getelementptr [[NULL]][{{.*}}] : (!llvm.ptr<array<100 x f32>>, i64) -> !llvm.ptr<array<100 x f32>>
+// CHECK: [[N:%.*]] = llvm.ptrtoint [[PTR]] : !llvm.ptr<array<100 x f32>> to i64
+// CHECK: [[MALLOC:%.*]] = llvm.call @malloc([[N]])
+// CHECK: [[B1:%.*]] = llvm.bitcast [[MALLOC]] : !llvm.ptr<i8> to !llvm.ptr<array<100 x f32>>
+// CHECK: [[B2:%.*]] = llvm.bitcast [[B1]] : !llvm.ptr<array<100 x f32>> to !llvm.ptr<i8>
+// CHECK: llvm.call @free([[B2]])
+// CHECK: llvm.return
+
+func @test_alloc_and_freemem_several() {
+  %z0 = fir.allocmem !fir.array<100xf32>
+  fir.freemem %z0 : !fir.heap<!fir.array<100xf32>>
+  return
+}
+
+// -----
+
 // Verify that fir.unreachable is transformed to llvm.unreachable
 
 // CHECK: llvm.func @test_unreachable() {