diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -225,8 +225,10 @@
   }];
   let constructor = "mlir::createLowerToLLVMPass()";
   let options = [
-    Option<"useAlloca", "use-alloca", "bool", /*default=*/"false",
-           "Use `alloca` instead of `call @malloc` for converting std.alloc">,
+    Option<"alignedAlloc", "aligned-alloc", "bool", /*default=*/"false",
+           "Use aligned_alloc in place of malloc for std.alloc whenever a "
+           "power-of-two alignment is set or the element type needs more "
+           "alignment than malloc guarantees">,
     Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
            /*default=*/"false",
            "Replace FuncOp's MemRef arguments with bare pointers to the MemRef "
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -20,8 +20,12 @@
 /// Collect a set of patterns to convert memory-related operations from the
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+/// If `alignedAlloc` is set, std.alloc is lowered to aligned_alloc instead of
+/// malloc whenever a power-of-two alignment is specified or the element type
+/// requires more alignment than malloc guarantees.
+void populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
@@ -37,16 +41,22 @@
     bool emitCWrappers = false);
 
 /// Collect a set of default patterns to convert from the Standard dialect to
-/// LLVM.
+/// LLVM. If `alignedAlloc` is set, std.alloc is lowered to aligned_alloc
+/// instead of malloc whenever a power-of-two alignment is specified or the
+/// element type requires more alignment than malloc guarantees.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false);
+                                         bool emitCWrappers = false,
+                                         bool alignedAlloc = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
-/// arguments.
+/// arguments. If `alignedAlloc` is set, std.alloc is lowered to aligned_alloc
+/// instead of malloc whenever a power-of-two alignment is specified or the
+/// element type requires more alignment than malloc guarantees.
 void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc = false);
 
 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
@@ -55,7 +65,12 @@
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Setting `alignedAlloc` (pass option `aligned-alloc=true`) lowers std.alloc
+/// to aligned_alloc instead of malloc whenever a power-of-two alignment is
+/// specified or the element type requires more alignment than malloc
+/// guarantees.
 std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass(
     bool useBarePtrCallConv = false, bool emitCWrappers = false,
-    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    bool alignedAlloc = false);
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1248,8 +1248,10 @@
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
 
-  explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
+  explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
+                               bool alignedAlloc = false)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
+        alignedAlloc(alignedAlloc) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1388,17 +1390,75 @@
         memRefType.getMemorySpace());
   }
 
+  /// Returns the size in bytes of one element of `memRefType`, rounding each
+  /// scalar up to a whole number of bytes, or None if the element type is
+  /// neither an integer/float nor a vector.
+  static Optional<uint64_t> getMemRefEltSizeInBytes(MemRefType memRefType) {
+    Type elementType = memRefType.getElementType();
+    if (auto vectorType = elementType.dyn_cast<VectorType>())
+      return vectorType.getNumElements() *
+             llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+    if (elementType.isIntOrFloat())
+      return llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+    return None;
+  }
+
+  /// Returns true if the byte size of `memRefType` is statically known to be
+  /// a multiple of `factor`. Only static shapes whose scalar element bit width
+  /// is a power of two of at least 8 are considered, since their in-memory
+  /// element size is then exactly bitwidth / 8; anything else returns false.
+  static bool isMemRefSizeMultipleOf(MemRefType memRefType, uint64_t factor) {
+    Type elementType = memRefType.getElementType();
+    if (!memRefType.hasStaticShape() || !elementType.isIntOrFloat())
+      return false;
+    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    if (bitWidth < 8 || !llvm::isPowerOf2_32(bitWidth))
+      return false;
+    uint64_t sizeInBytes = memRefType.getNumElements() * (bitWidth / 8);
+    return sizeInBytes % factor == 0;
+  }
+
+  /// Returns the alignment to pass to aligned_alloc for `op`, or None if the
+  /// buffer must be obtained some other way (malloc for heap allocations,
+  /// alloca for stack ones). aligned_alloc is only chosen when the pattern was
+  /// created with `alignedAlloc` set, so the default lowering keeps malloc.
+  Optional<int64_t> getAlignedAllocAlignment(Operation *op) const {
+    auto allocOp = dyn_cast<AllocOp>(op);
+    if (!alignedAlloc || !allocOp)
+      return None;
+
+    // Honor an explicit alignment directly when aligned_alloc accepts it, i.e.
+    // when it is a positive power of two. Anything else falls back to malloc,
+    // which over-allocates and aligns the pointer manually.
+    if (Optional<APInt> alignment = allocOp.alignment()) {
+      int64_t value = alignment->getSExtValue();
+      if (value > 0 && llvm::isPowerOf2_64(value))
+        return value;
+      return None;
+    }
+
+    // Without an explicit alignment, still use aligned_alloc when the element
+    // needs more alignment than malloc guarantees. aligned_alloc only accepts
+    // power-of-two alignments, so round the element size up to one.
+    Optional<uint64_t> eltSizeBytes =
+        getMemRefEltSizeInBytes(allocOp.getType());
+    if (!eltSizeBytes || *eltSizeBytes <= kMallocAlignment)
+      return None;
+    return static_cast<int64_t>(llvm::PowerOf2Ceil(*eltSizeBytes));
+  }
+
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
-  /// alignment is neeeded post allocation (for eg. in conjunction with malloc).
-  /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
+  /// alignment is needed post allocation (e.g. in conjunction with malloc).
+  /// Heap buffers come from aligned_alloc when getAlignedAllocAlignment
+  /// returns an alignment, and from malloc otherwise.
   Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
                        MemRefType memRefType, Value one, Value &accessAlignment,
                        Value &allocatedBytePtr,
                        ConversionPatternRewriter &rewriter) const {
     auto elementPtrType = getElementPtrType(memRefType);
 
-    // Whether to use std lib function aligned_alloc that supports alignment.
+    // Alignment requested on the op; applied directly to stack allocations.
     Optional<APInt> allocationAlignment = cast<AllocLikeOp>(op).alignment();
 
     // With alloca, one gets a pointer to the element type right away.
@@ -1412,14 +1472,20 @@
                               : 0);
     }
 
-    // Use malloc. Insert the malloc declaration if it is not already present.
-    auto allocFuncName = "malloc";
+    // Use aligned_alloc if an alignment was computed for it and malloc
+    // otherwise. Insert the declaration if it is not already present.
+    Optional<int64_t> alignedAllocAlignment = getAlignedAllocAlignment(op);
+    bool useAlignedAlloc = alignedAllocAlignment.hasValue();
+    auto allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc";
     AllocOp allocOp = cast<AllocOp>(op);
     auto module = allocOp.getParentOfType<ModuleOp>();
     auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
     if (!allocFunc) {
       OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
       SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
+      // aligned_alloc(size_t alignment, size_t size)
+      if (useAlignedAlloc)
+        callArgTypes.push_back(getIndexType());
       allocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), allocFuncName,
           LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
@@ -1429,16 +1495,39 @@
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     SmallVector<Value, 2> callArgs;
-    // Adjust the allocation size to consider alignment.
-    if (allocOp.alignment()) {
-      accessAlignment = createIndexConstant(
-          rewriter, loc, allocOp.alignment().getValue().getSExtValue());
-      cumulativeSize = rewriter.create<LLVM::SubOp>(
-          loc,
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
-          one);
+    if (useAlignedAlloc) {
+      // aligned_alloc returns suitably aligned memory: no post-allocation
+      // realignment is needed.
+      accessAlignment = nullptr;
+      int64_t alignment = alignedAllocAlignment.getValue();
+      // Pass the alignment as an index so that it matches the declared
+      // signature whatever the index bitwidth is.
+      Value alignmentValue = createIndexConstant(rewriter, loc, alignment);
+      // C11 requires the size to be a multiple of the alignment; round it up
+      // unless that is statically known to hold already.
+      if (!isMemRefSizeMultipleOf(memRefType, alignment)) {
+        Value alignmentMinusOne =
+            createIndexConstant(rewriter, loc, alignment - 1);
+        Value padded = rewriter.create<LLVM::AddOp>(loc, cumulativeSize,
+                                                    alignmentMinusOne);
+        Value numChunks =
+            rewriter.create<LLVM::UDivOp>(loc, padded, alignmentValue);
+        cumulativeSize =
+            rewriter.create<LLVM::MulOp>(loc, numChunks, alignmentValue);
+      }
+      callArgs = {alignmentValue, cumulativeSize};
+    } else {
+      // Adjust the allocation size to consider alignment.
+      if (allocOp.alignment()) {
+        accessAlignment = createIndexConstant(
+            rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+        cumulativeSize = rewriter.create<LLVM::SubOp>(
+            loc,
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
+            one);
+      }
+      callArgs.push_back(cumulativeSize);
     }
-    callArgs.push_back(cumulativeSize);
     auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
     allocatedBytePtr = rewriter
                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
@@ -1510,11 +1599,21 @@
     // Return the final value of the descriptor.
     rewriter.replaceOp(op, {memRefDescriptor});
   }
+
+protected:
+  /// Alignment (in bytes) that malloc is assumed to provide; 16 bytes is what
+  /// common 64-bit C libraries typically guarantee.
+  constexpr static unsigned kMallocAlignment = 16;
+  /// Whether std.alloc may be lowered to aligned_alloc instead of malloc.
+  bool alignedAlloc;
 };
 
 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
-  using Base::Base;
+  explicit AllocOpLowering(LLVMTypeConverter &converter,
+                           bool alignedAlloc = false)
+      : AllocLikeOpLowering<AllocOp>(converter, alignedAlloc) {}
 };
+
 struct AllocaOpLowering : public AllocLikeOpLowering<AllocaOp> {
   using Base::Base;
 };
@@ -2738,11 +2837,13 @@
   // clang-format on
 }
 
-void mlir::populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
+      DeallocOpLowering,
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
@@ -2750,8 +2851,7 @@
       SubViewOpLowering,
       ViewOpLowering>(converter);
   patterns.insert<
-      AllocOpLowering,
-      DeallocOpLowering>(converter);
+      AllocOpLowering>(converter, alignedAlloc);
   // clang-format on
 }
 
@@ -2763,11 +2863,11 @@
 
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
+    bool emitCWrappers, bool alignedAlloc) {
   populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
                                                   emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, alignedAlloc);
 }
 
 static void populateStdToLLVMBarePtrFuncOpConversionPattern(
@@ -2776,10 +2876,11 @@
 }
 
 void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool alignedAlloc) {
   populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns, alignedAlloc);
 }
 
 // Create an LLVM IR structure type if there is more than one result.
@@ -2854,7 +2955,8 @@
 
   /// Creates an LLVM lowering pass.
   LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
-                   unsigned indexBitwidth) {
+                   unsigned indexBitwidth, bool alignedAlloc) {
+    this->alignedAlloc = alignedAlloc;
     this->useBarePtrCallConv = useBarePtrCallConv;
     this->emitCWrappers = emitCWrappers;
     this->indexBitwidth = indexBitwidth;
@@ -2883,8 +2985,9 @@
     OwningRewritePatternList patterns;
     if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns);
+      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
+                                                 alignedAlloc);
     else
-      populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers);
+      populateStdToLLVMConversionPatterns(typeConverter, patterns,
+                                          emitCWrappers, alignedAlloc);
 
     LLVMConversionTarget target(getContext());
@@ -2903,7 +3006,7 @@
 
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers,
-                            unsigned indexBitwidth) {
+                            unsigned indexBitwidth, bool alignedAlloc) {
   return std::make_unique<LLVMLoweringPass>(useBarePtrCallConv, emitCWrappers,
-                                            indexBitwidth);
+                                            indexBitwidth, alignedAlloc);
 }
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm<"float*">
@@ -138,6 +139,46 @@
   return
 }
 
+// CHECK-LABEL: func @stdlib_aligned_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// The default lowering never uses aligned_alloc.
+// CHECK-NOT: llvm.call @aligned_alloc
+// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @stdlib_aligned_alloc() -> memref<32x18xf32> {
+// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*">
+// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+  %0 = alloc() {alignment = 32} : memref<32x18xf32>
+  // Do another alloc just to test that we have a unique declaration for
+  // aligned_alloc.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %1 = alloc() {alignment = 64} : memref<4096xf32>
+
+  // Memrefs with element types larger than 16 bytes are allocated via
+  // aligned_alloc even without an alignment attribute.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %2 = alloc() : memref<4096xvector<8xf32>>
+  // While those with elt type <= 16 bytes still use malloc unless the alignment
+  // attribute is specified.
+  // ALIGNED-ALLOC: llvm.call @malloc
+  %3 = alloc() : memref<4096xvector<2xf32>>
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %4 = alloc() {alignment = 16} : memref<4096xvector<2xf32>>
+  // Element sizes that are not a power of two are rounded up to one, since
+  // aligned_alloc only accepts power-of-two alignments (20 bytes -> 32).
+  // ALIGNED-ALLOC: %[[align32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[align32]],
+  %5 = alloc() : memref<4096xvector<5xf32>>
+  return %0 : memref<32x18xf32>
+}
+
 // CHECK-LABEL: func @mixed_load(
 // CHECK-COUNT-2: !llvm<"float*">,
 // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64