diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -225,6 +225,9 @@
   }];
   let constructor = "mlir::createLowerToLLVMPass()";
   let options = [
+    Option<"alignedAlloc", "aligned-alloc", "bool", /*default=*/"false",
+           "Use aligned_alloc in place of malloc whenever alignment is set on "
+           "std.alloc">,
    Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
           /*default=*/"false",
           "Replace FuncOp's MemRef arguments with bare pointers to the MemRef "
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -20,8 +20,9 @@
 /// Collect a set of patterns to convert memory-related operations from the
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc);
 
 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
@@ -37,16 +38,20 @@
     bool emitCWrappers = false);
 
 /// Collect a set of default patterns to convert from the Standard dialect to
-/// LLVM.
+/// LLVM. If `alignedAlloc` is set, aligned_alloc is used in place of malloc
+/// whenever an alignment is specified.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false);
+                                         bool emitCWrappers = false,
+                                         bool alignedAlloc = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
-/// arguments.
+/// arguments. If `alignedAlloc` is set, aligned_alloc is used in place of
+/// malloc whenever an alignment is specified.
 void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc);
 
 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
@@ -56,15 +61,17 @@
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
   unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+  bool alignedAlloc = false;
 };
 
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
-/// stdlib malloc/free is used for allocating memrefs allocated with std.alloc,
-/// while LLVM's alloca is used for those allocated with std.alloca.
-std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass(
-    const LowerToLLVMOptions &options = {
-        /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
-        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout});
+/// By default, stdlib malloc/free is used for memrefs allocated with
+/// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
+std::unique_ptr<OperationPass<ModuleOp>>
+createLowerToLLVMPass(const LowerToLLVMOptions &options = {
+                          /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
+                          /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
+                          /*alignedAlloc=*/false});
 
 } // namespace mlir
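Note on usage: besides the `aligned-alloc` pass option declared in Passes.td above, the flag is reachable programmatically through `LowerToLLVMOptions`. A minimal sketch of how a downstream pipeline might enable it follows; the pass-manager scaffolding (`addLoweringPasses` and its surroundings) is illustrative and not part of this patch.

    // Hypothetical pipeline hook; only LowerToLLVMOptions and
    // createLowerToLLVMPass come from the header changed above.
    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    #include "mlir/Pass/PassManager.h"

    void addLoweringPasses(mlir::PassManager &pm) {
      mlir::LowerToLLVMOptions options;
      // Lower std.alloc ops that carry an alignment through aligned_alloc.
      options.alignedAlloc = true;
      pm.addPass(mlir::createLowerToLLVMPass(options));
    }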
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1248,8 +1248,10 @@
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
 
-  explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
+  explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
+                               bool alignedAlloc = false)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
+        alignedAlloc(alignedAlloc) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1388,38 +1390,78 @@
                            memRefType.getMemorySpace());
   }
 
+  /// Returns the alignment to be used for the allocation call itself.
+  /// Returning a value means aligned_alloc has to be used for the allocation.
+  Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
+    auto elementType = allocOp.getType().getElementType();
+
+    // aligned_alloc was requested and the op carries an alignment: use it.
+    if (alignedAlloc && allocOp.alignment())
+      return allocOp.alignment().getValue().getSExtValue();
+
+    // The alloc op has an alignment attribute, but aligned_alloc was not
+    // requested; malloc followed by explicit realignment is used instead.
+    if (allocOp.alignment())
+      return None;
+
+    // Without an alignment attribute, still fall back to aligned_alloc when
+    // the element type needs more alignment than malloc guarantees.
+    uint64_t constEltSizeBytes = 0;
+    auto isMallocAlignmentInsufficient = [&]() {
+      if (auto vectorType = elementType.template dyn_cast<VectorType>())
+        constEltSizeBytes =
+            vectorType.getNumElements() *
+            llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+      else
+        constEltSizeBytes =
+            llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+
+      // aligned_alloc is needed if the element size exceeds malloc's
+      // guaranteed alignment.
+      return constEltSizeBytes > kMallocAlignment;
+    };
+    if (isMallocAlignmentInsufficient())
+      return std::max(constEltSizeBytes,
+                      static_cast<uint64_t>(kMallocAlignment));
+    return None;
+  }
+
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is needed post allocation (e.g. in conjunction with malloc).
-  /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
   Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
                        MemRefType memRefType, Value one, Value &accessAlignment,
                        Value &allocatedBytePtr,
                        ConversionPatternRewriter &rewriter) const {
     auto elementPtrType = getElementPtrType(memRefType);
 
-    // Whether to use std lib function aligned_alloc that supports alignment.
-    Optional<APInt> allocationAlignment = cast<AllocLikeOp>(op).alignment();
-
-    // With alloca, one gets a pointer to the element type right away.
-    bool onStack = isa<AllocaOp>(op);
-    if (onStack) {
+    // For stack allocations.
+    if (auto allocaOp = dyn_cast<AllocaOp>(op)) {
      allocatedBytePtr = nullptr;
      accessAlignment = nullptr;
      return rewriter.create<LLVM::AllocaOp>(
          loc, elementPtrType, cumulativeSize,
-          allocationAlignment ? allocationAlignment.getValue().getSExtValue()
-                              : 0);
+          allocaOp.alignment() ? allocaOp.alignment().getValue().getSExtValue()
+                               : 0);
    }
 
-    // Use malloc. Insert the malloc declaration if it is not already present.
-    auto allocFuncName = "malloc";
+    // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
+
+    Optional<int64_t> allocationAlignment = getAllocationAlignment(allocOp);
+    // Whether to use std lib function aligned_alloc that supports alignment.
+    bool useAlignedAlloc = allocationAlignment.hasValue();
+
+    // Insert the malloc/aligned_alloc declaration if it is not already present.
+    auto allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc";
     auto module = allocOp.getParentOfType<ModuleOp>();
     auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
     if (!allocFunc) {
       OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
       SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
+      // aligned_alloc(size_t alignment, size_t size)
+      if (useAlignedAlloc)
+        callArgTypes.push_back(getIndexType());
       allocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), allocFuncName,
           LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
@@ -1429,16 +1471,25 @@
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     SmallVector<Value, 2> callArgs;
-    // Adjust the allocation size to consider alignment.
-    if (allocOp.alignment()) {
-      accessAlignment = createIndexConstant(
-          rewriter, loc, allocOp.alignment().getValue().getSExtValue());
-      cumulativeSize = rewriter.create<LLVM::SubOp>(
-          loc,
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
-          one);
+    if (useAlignedAlloc) {
+      assert(allocationAlignment && "allocation alignment should be present");
+      // aligned_alloc returns a suitably aligned pointer directly; there is no
+      // need to over-allocate and realign, so `accessAlignment` stays null.
+      auto alignedAllocAlignmentValue = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter.convertType(rewriter.getIntegerType(64)),
+          rewriter.getI64IntegerAttr(allocationAlignment.getValue()));
+      callArgs = {alignedAllocAlignmentValue, cumulativeSize};
+    } else {
+      // Adjust the allocation size to consider alignment.
+      if (allocOp.alignment()) {
+        accessAlignment = createIndexConstant(
+            rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+        cumulativeSize = rewriter.create<LLVM::SubOp>(
+            loc,
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
+            one);
+      }
+      callArgs.push_back(cumulativeSize);
     }
-    callArgs.push_back(cumulativeSize);
     auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
     allocatedBytePtr = rewriter
                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
@@ -1510,11 +1561,19 @@
     // Return the final value of the descriptor.
     rewriter.replaceOp(op, {memRefDescriptor});
   }
+
+protected:
+  // The alignment malloc is typically guaranteed to provide (16 bytes on most
+  // 64-bit platforms); element types needing more go through aligned_alloc.
+  constexpr static unsigned kMallocAlignment = 16;
+  bool alignedAlloc;
 };
 
 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
-  using Base::Base;
+  explicit AllocOpLowering(LLVMTypeConverter &converter,
+                           bool alignedAlloc = false)
+      : AllocLikeOpLowering<AllocOp>(converter, alignedAlloc) {}
 };
+
 struct AllocaOpLowering : public AllocLikeOpLowering<AllocaOp> {
   using Base::Base;
 };
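For readers skimming the lowering above, this is roughly what the two heap paths amount to at the libc level. The sketch below is a conceptual illustration only (the names are made up; it is not the IR the pass emits): the malloc path over-allocates and realigns the access pointer, while the aligned_alloc path receives a correctly aligned pointer directly, which is why `accessAlignment` stays null on that path.

    #include <cstdint>
    #include <cstdlib>

    // malloc path: allocate size + alignment - 1 bytes, then round the access
    // pointer up to the requested alignment (mirrors the AddOp/SubOp above).
    void *allocWithMalloc(std::size_t sizeBytes, std::size_t alignment,
                          void **allocatedPtr) {
      *allocatedPtr = std::malloc(sizeBytes + alignment - 1);
      auto raw = reinterpret_cast<std::uintptr_t>(*allocatedPtr);
      return reinterpret_cast<void *>((raw + alignment - 1) & ~(alignment - 1));
    }

    // aligned_alloc path: the returned pointer is already aligned.
    // (C11 requires size to be a multiple of alignment; glibc is lenient.)
    void *allocWithAlignedAlloc(std::size_t sizeBytes, std::size_t alignment) {
      return std::aligned_alloc(alignment, sizeBytes);
    }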
@@ -2738,11 +2797,13 @@
   // clang-format on
 }
 
-void mlir::populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
+      DeallocOpLowering,
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
@@ -2750,8 +2811,8 @@
       SubViewOpLowering,
       ViewOpLowering>(converter);
   patterns.insert<
-      AllocOpLowering,
-      DeallocOpLowering>(converter);
+      AllocOpLowering>(converter, alignedAlloc);
   // clang-format on
 }
 
@@ -2763,11 +2824,11 @@
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
+    bool emitCWrappers, bool alignedAlloc) {
   populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
                                                   emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, alignedAlloc);
 }
 
 static void populateStdToLLVMBarePtrFuncOpConversionPattern(
@@ -2776,10 +2837,11 @@
 }
 
 void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc) {
   populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, alignedAlloc);
 }
 
 // Create an LLVM IR structure type if there is more than one result.
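Pipelines that assemble the conversion manually, rather than running the pass, get the same control through the extra `alignedAlloc` parameter threaded into the populate functions above. A rough sketch of such a use; the helper function, its error handling, and the exact include set are assumptions, only the populate/converter entry points shown in this patch are taken as given.

    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    #include "mlir/Transforms/DialectConversion.h"

    // Hypothetical helper mirroring LLVMLoweringPass::runOnOperation, but with
    // aligned_alloc lowering switched on.
    mlir::LogicalResult lowerWithAlignedAlloc(mlir::ModuleOp module) {
      mlir::LLVMTypeConverter typeConverter(module.getContext());
      mlir::OwningRewritePatternList patterns;
      mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns,
                                                /*emitCWrappers=*/false,
                                                /*alignedAlloc=*/true);
      mlir::LLVMConversionTarget target(*module.getContext());
      return mlir::applyPartialConversion(module, target, patterns,
                                          &typeConverter);
    }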
@@ -2854,10 +2916,11 @@
   /// Creates an LLVM lowering pass.
   LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
-                   unsigned indexBitwidth) {
+                   unsigned indexBitwidth, bool alignedAlloc) {
     this->useBarePtrCallConv = useBarePtrCallConv;
     this->emitCWrappers = emitCWrappers;
     this->indexBitwidth = indexBitwidth;
+    this->alignedAlloc = alignedAlloc;
   }
   explicit LLVMLoweringPass() {}
   LLVMLoweringPass(const LLVMLoweringPass &pass) {}
@@ -2882,10 +2945,11 @@
     OwningRewritePatternList patterns;
     if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns);
+      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
+                                                 alignedAlloc);
     else
       populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers);
+                                          emitCWrappers, alignedAlloc);
 
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, patterns, &typeConverter)))
@@ -2904,5 +2968,6 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
   return std::make_unique<LLVMLoweringPass>(
-      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth);
+      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
+      options.alignedAlloc);
 }
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm<"float*">
@@ -138,6 +139,39 @@
   return
 }
 
+// CHECK-LABEL: func @stdlib_aligned_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @stdlib_aligned_alloc() -> memref<32x18xf32> {
+// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*">
+// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+  %0 = alloc() {alignment = 32} : memref<32x18xf32>
+  // Do another alloc just to test that we have a unique declaration for
+  // aligned_alloc.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %1 = alloc() {alignment = 64} : memref<4096xf32>
+
+  // Memrefs with element types larger than 16 bytes are allocated via
+  // aligned_alloc even without an alignment attribute.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %2 = alloc() : memref<4096xvector<8xf32>>
+  // Memrefs whose element types take 16 bytes or less still go through malloc,
+  // unless an alignment attribute is specified.
+  // ALIGNED-ALLOC: llvm.call @malloc
+  %3 = alloc() : memref<4096xvector<2xf32>>
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %4 = alloc() {alignment = 16} : memref<4096xvector<2xf32>>
+  return %0 : memref<32x18xf32>
+}
+
 // CHECK-LABEL: func @mixed_load(
 // CHECK-COUNT-2: !llvm<"float*">,
 // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
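As a quick sanity check of the element-size heuristic exercised by the test above: vector<8xf32> occupies 8 x 4 = 32 bytes, which exceeds kMallocAlignment (16), so it is routed to aligned_alloc with alignment 32, while vector<2xf32> occupies 8 bytes and stays on malloc. A tiny standalone recap of that decision follows; the byte sizes are hard-coded here for illustration and this is not code from the patch.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    constexpr std::uint64_t kMallocAlignment = 16;

    // Returns 0 when plain malloc suffices, otherwise the aligned_alloc alignment.
    std::uint64_t allocationAlignmentFor(std::uint64_t eltSizeBytes) {
      if (eltSizeBytes <= kMallocAlignment)
        return 0;
      return std::max(eltSizeBytes, kMallocAlignment);
    }

    int main() {
      std::cout << allocationAlignmentFor(8 * 4) << "\n"; // vector<8xf32> -> 32
      std::cout << allocationAlignmentFor(2 * 4) << "\n"; // vector<2xf32> -> 0 (malloc)
    }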