diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -226,6 +226,8 @@
   }];
   let constructor = "mlir::createLowerToLLVMPass()";
   let options = [
+    Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
+           "Use aligned_alloc in place of malloc for heap allocations">,
    Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
           /*default=*/"false",
           "Replace FuncOp's MemRef arguments with bare pointers to the MemRef "
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -20,8 +20,9 @@
 /// Collect a set of patterns to convert memory-related operations from the
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc);
 
 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
@@ -40,13 +41,15 @@
 /// LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false);
+                                         bool emitCWrappers = false,
+                                         bool useAlignedAlloc = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
 /// arguments.
 void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc);
 
 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
@@ -56,15 +59,18 @@
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
   unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+  /// Use aligned_alloc for heap allocations.
+  bool useAlignedAlloc = false;
 };
 
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
-/// stdlib malloc/free is used for allocating memrefs allocated with std.alloc,
-/// while LLVM's alloca is used for those allocated with std.alloca.
-std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass(
-    const LowerToLLVMOptions &options = {
-        /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
-        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout});
+/// stdlib malloc/free is used by default for allocating memrefs allocated with
+/// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
+std::unique_ptr<OperationPass<ModuleOp>>
+createLowerToLLVMPass(const LowerToLLVMOptions &options = {
+                          /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
+                          /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
+                          /*useAlignedAlloc=*/false});
 
 } // namespace mlir
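A note for readers, not part of the patch itself: the new behavior can be requested either through the pass option declared in Passes.td above, i.e. mlir-opt -convert-std-to-llvm='use-aligned-alloc=1', or programmatically through the extended LowerToLLVMOptions. A minimal sketch assuming only the declarations above; the helper name is made up for the example:

  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
  #include "mlir/Pass/PassManager.h"

  // Sketch: build a pipeline that lowers Standard to LLVM and uses
  // aligned_alloc instead of malloc for std.alloc.
  void addLoweringWithAlignedAlloc(mlir::PassManager &pm) {
    mlir::LowerToLLVMOptions options = {
        /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
        /*indexBitwidth=*/mlir::kDeriveIndexBitwidthFromDataLayout,
        /*useAlignedAlloc=*/true};
    pm.addPass(mlir::createLowerToLLVMPass(options));
  }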
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -344,7 +344,6 @@
   return success();
 }
 
-// TODO(mlir-team): improve/complete this when we have target data.
 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1248,8 +1248,10 @@
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
 
-  explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
+  explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
+                               bool useAlignedAlloc = false)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
+        useAlignedAlloc(useAlignedAlloc) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1271,6 +1273,18 @@
     return success();
   }
 
+  // Returns bump = (alignment - (input % alignment)) % alignment, which is
+  // the increment necessary to align `input` to an `alignment` boundary.
+  // TODO: this can be made more efficient with a single addition and two bit
+  // shifts, ((input + alignment - 1) / alignment) * alignment, since the
+  // alignment is always a power of two.
+  Value createBumpToAlign(Location loc, OpBuilder b, Value input,
+                          Value alignment) const {
+    Value modAlign = b.create<LLVM::URemOp>(loc, input, alignment);
+    Value diff = b.create<LLVM::SubOp>(loc, alignment, modAlign);
+    Value shift = b.create<LLVM::URemOp>(loc, diff, alignment);
+    return shift;
+  }
+
   /// Creates and populates the memref descriptor struct given all its fields.
   /// This method also performs any post allocation alignment needed for heap
   /// allocations when `accessAlignment` is non null. This is used with
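To make the arithmetic of createBumpToAlign above concrete (illustration only, helper names made up, not part of the patch): for input = 400 and alignment = 32, bump = (32 - 400 % 32) % 32 = 16, so 400 + 16 = 416 is the next multiple of 32; when the input is already aligned the outer modulo yields 0. A scalar model in C++:

  #include <cassert>
  #include <cstdint>

  // Scalar model of the IR emitted by createBumpToAlign: the increment that
  // rounds `input` up to the next multiple of `alignment`.
  uint64_t bumpToAlign(uint64_t input, uint64_t alignment) {
    return (alignment - input % alignment) % alignment;
  }

  int main() {
    assert(bumpToAlign(400, 32) == 16); // 400 + 16 == 416, a multiple of 32
    assert(bumpToAlign(416, 32) == 0);  // already aligned, no bump needed
    // The cheaper form from the TODO, valid since alignment is a power of two:
    assert(((400 + 32 - 1) / 32) * 32 == 416);
  }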
@@ -1292,12 +1306,7 @@
       // offset = (align - (ptr % align))% align
       Value intVal = rewriter.create<LLVM::PtrToIntOp>(
           loc, this->getIndexType(), allocatedBytePtr);
-      Value ptrModAlign =
-          rewriter.create<LLVM::URemOp>(loc, intVal, accessAlignment);
-      Value subbed =
-          rewriter.create<LLVM::SubOp>(loc, accessAlignment, ptrModAlign);
-      Value offset =
-          rewriter.create<LLVM::URemOp>(loc, subbed, accessAlignment);
+      Value offset = createBumpToAlign(loc, rewriter, intVal, accessAlignment);
       Value aligned = rewriter.create<LLVM::GEPOp>(
           loc, allocatedBytePtr.getType(), allocatedBytePtr, offset);
       alignedBytePtr = rewriter.create<LLVM::BitcastOp>(
           loc,
@@ -1388,38 +1397,90 @@
                                  memRefType.getMemorySpace());
   }
 
+  /// Returns the memref's element size in bytes.
+  // TODO: there are other places where this is used. Expose publicly?
+  static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+    auto elementType = memRefType.getElementType();
+
+    unsigned sizeInBits;
+    if (elementType.isIntOrFloat()) {
+      sizeInBits = elementType.getIntOrFloatBitWidth();
+    } else {
+      auto vectorType = elementType.cast<VectorType>();
+      sizeInBits =
+          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    }
+    return llvm::divideCeil(sizeInBits, 8);
+  }
+
+  /// Returns the alignment to be used for the allocation call itself.
+  /// aligned_alloc requires the alignment to be a power of two, and the
+  /// allocation size to be a multiple of that alignment.
+  Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
+    // No alignment can be used for the 'malloc' call itself.
+    if (!useAlignedAlloc)
+      return None;
+
+    if (allocOp.alignment())
+      return allocOp.alignment().getValue().getSExtValue();
+
+    // Whenever no alignment is set, use an alignment consistent with the
+    // element type; since the alignment has to be a power of two, bump the
+    // element size to the next power of two if it isn't one already.
+    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
+    return std::max(kMinAlignedAllocAlignment,
+                    llvm::PowerOf2Ceil(eltSizeBytes));
+  }
+
+  /// Returns true if the memref size in bytes is known to be a multiple of
+  /// factor.
+  static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
+    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
+    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
+      if (type.isDynamic(type.getDimSize(i)))
+        continue;
+      sizeDivisor = sizeDivisor * type.getDimSize(i);
+    }
+    return sizeDivisor % factor == 0;
+  }
+
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is neeeded post allocation (for eg. in conjunction with malloc).
-  /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
   Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
                        MemRefType memRefType, Value one, Value &accessAlignment,
                        Value &allocatedBytePtr,
                        ConversionPatternRewriter &rewriter) const {
     auto elementPtrType = getElementPtrType(memRefType);
 
-    // Whether to use std lib function aligned_alloc that supports alignment.
-    Optional<APInt> allocationAlignment = cast<AllocLikeOp>(op).alignment();
-    // With alloca, one gets a pointer to the element type right away.
-    bool onStack = isa<AllocaOp>(op);
-    if (onStack) {
+    // For stack allocations.
+    if (auto allocaOp = dyn_cast<AllocaOp>(op)) {
       allocatedBytePtr = nullptr;
       accessAlignment = nullptr;
       return rewriter.create<LLVM::AllocaOp>(
           loc, elementPtrType, cumulativeSize,
-          allocationAlignment ? allocationAlignment.getValue().getSExtValue()
-                              : 0);
+          allocaOp.alignment() ? allocaOp.alignment().getValue().getSExtValue()
+                               : 0);
     }
 
-    // Use malloc. Insert the malloc declaration if it is not already present.
-    auto allocFuncName = "malloc";
+    // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
+
+    Optional<int64_t> allocationAlignment = getAllocationAlignment(allocOp);
+    // Whether to use std lib function aligned_alloc that supports alignment.
+    bool useAlignedAlloc = allocationAlignment.hasValue();
+
+    // Insert the malloc/aligned_alloc declaration if it is not already present.
+    auto allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc";
     auto module = allocOp.getParentOfType<ModuleOp>();
     auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
     if (!allocFunc) {
       OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
       SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
+      // aligned_alloc(size_t alignment, size_t size)
+      if (useAlignedAlloc)
+        callArgTypes.push_back(getIndexType());
       allocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), allocFuncName,
           LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
@@ -1429,16 +1490,33 @@
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     SmallVector<Value, 2> callArgs;
-    // Adjust the allocation size to consider alignment.
-    if (allocOp.alignment()) {
-      accessAlignment = createIndexConstant(
-          rewriter, loc, allocOp.alignment().getValue().getSExtValue());
-      cumulativeSize = rewriter.create<LLVM::SubOp>(
-          loc,
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
-          one);
+    if (useAlignedAlloc) {
+      // Use aligned_alloc.
+      assert(allocationAlignment && "allocation alignment should be present");
+      auto alignedAllocAlignmentValue = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter.convertType(rewriter.getIntegerType(64)),
+          rewriter.getI64IntegerAttr(allocationAlignment.getValue()));
+      // aligned_alloc requires size to be a multiple of alignment; we will pad
+      // the size to the next multiple if necessary.
+      if (!isMemRefSizeMultipleOf(memRefType, allocationAlignment.getValue())) {
+        Value bump = createBumpToAlign(loc, rewriter, cumulativeSize,
+                                       alignedAllocAlignmentValue);
+        cumulativeSize =
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, bump);
+      }
+      callArgs = {alignedAllocAlignmentValue, cumulativeSize};
+    } else {
+      // Adjust the allocation size to consider alignment.
+      if (allocOp.alignment()) {
+        accessAlignment = createIndexConstant(
+            rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+        cumulativeSize = rewriter.create<LLVM::SubOp>(
+            loc,
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
+            one);
+      }
+      callArgs.push_back(cumulativeSize);
     }
-    callArgs.push_back(cumulativeSize);
     auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
     allocatedBytePtr = rewriter
                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
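The padding in the hunk above follows the library contract of aligned_alloc: the alignment must be a power of two supported by the implementation, and C11 additionally requires the size to be an integral multiple of the alignment (later revisions relaxed this), so the lowering rounds the size up rather than relying on a lenient allocator. A standalone sketch of the same computation, illustrative only and with a made-up helper name:

  #include <cstdint>
  #include <cstdlib>

  // What the emitted IR effectively does on the aligned_alloc path: pad the
  // byte size up to a multiple of the (power-of-two) alignment, then call
  // aligned_alloc(alignment, paddedSize).
  void *allocAlignedBuffer(uint64_t sizeInBytes, uint64_t alignment) {
    uint64_t bump = (alignment - sizeInBytes % alignment) % alignment;
    return std::aligned_alloc(alignment, sizeInBytes + bump);
  }
  // E.g. a memref<100xf32> with alignment 32: 400 bytes are padded to 416.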
@@ -1510,11 +1588,20 @@
     // Return the final value of the descriptor.
     rewriter.replaceOp(op, {memRefDescriptor});
   }
+
+protected:
+  /// Use aligned_alloc instead of malloc for all heap allocations.
+  bool useAlignedAlloc;
+  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+  uint64_t kMinAlignedAllocAlignment = 16UL;
 };
 
 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
-  using Base::Base;
+  explicit AllocOpLowering(LLVMTypeConverter &converter,
+                           bool useAlignedAlloc = false)
+      : AllocLikeOpLowering<AllocOp>(converter, useAlignedAlloc) {}
 };
+
 struct AllocaOpLowering : public AllocLikeOpLowering<AllocaOp> {
   using Base::Base;
 };
@@ -2738,11 +2825,13 @@
   // clang-format on
 }
 
-void mlir::populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
+      DeallocOpLowering,
      DimOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
@@ -2750,8 +2839,8 @@
      SubViewOpLowering,
      ViewOpLowering>(converter);
   patterns.insert<
-      AllocOpLowering,
-      DeallocOpLowering>(converter);
+      AllocOpLowering
+      >(converter, useAlignedAlloc);
   // clang-format on
 }
 
@@ -2763,11 +2852,12 @@
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
+    bool emitCWrappers, bool useAlignedAlloc) {
   populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
                                                   emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+                                            useAlignedAlloc);
 }
 
 static void populateStdToLLVMBarePtrFuncOpConversionPattern(
@@ -2776,10 +2866,12 @@
 }
 
 void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc) {
   populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+                                            useAlignedAlloc);
 }
 
 // Create an LLVM IR structure type if there is more than one result.
@@ -2850,10 +2942,11 @@
 struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
   LLVMLoweringPass() = default;
   LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
-                   unsigned indexBitwidth) {
+                   unsigned indexBitwidth, bool useAlignedAlloc) {
     this->useBarePtrCallConv = useBarePtrCallConv;
     this->emitCWrappers = emitCWrappers;
     this->indexBitwidth = indexBitwidth;
+    this->useAlignedAlloc = useAlignedAlloc;
   }
 
   /// Run the dialect converter on the module.
@@ -2876,10 +2969,11 @@
     OwningRewritePatternList patterns;
     if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns);
+      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
+                                                 useAlignedAlloc);
     else
       populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers);
+                                          emitCWrappers, useAlignedAlloc);
 
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, patterns, &typeConverter)))
@@ -2898,5 +2992,6 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
   return std::make_unique<LLVMLoweringPass>(
-      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth);
+      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
+      options.useAlignedAlloc);
 }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1068,7 +1068,7 @@
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // Constant fold dim when the size along the index referred to is a constant.
   auto opType = memrefOrTensor().getType();
-  int64_t indexSize = -1;
+  int64_t indexSize = ShapedType::kDynamicSize;
   if (auto tensorType = opType.dyn_cast<RankedTensorType>())
     indexSize = tensorType.getShape()[getIndex()];
   else if (auto memrefType = opType.dyn_cast<MemRefType>())
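Aside on the Ops.cpp change above (not part of the lowering work itself): it only replaces the hard-coded -1 with the named sentinel; ShapedType::kDynamicSize is the value MLIR uses for dynamic dimension sizes, so the constant documents what the initializer means. A tiny illustration, assuming ShapedType still lives in mlir/IR/StandardTypes.h as it did around this patch and with a made-up helper name:

  #include "mlir/IR/StandardTypes.h"

  // kDynamicSize (== -1 at the time of writing) marks a dynamic dimension.
  static bool hasStaticDim(mlir::ShapedType type, unsigned idx) {
    return !mlir::ShapedType::isDynamic(type.getDimSize(idx));
  }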
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm<"float*">
@@ -138,6 +139,52 @@
   return
 }
 
+// CHECK-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
+// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*">
+// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+  %0 = alloc() {alignment = 32} : memref<32x18xf32>
+  // Do another alloc just to test that we have a unique declaration for
+  // aligned_alloc.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %1 = alloc() {alignment = 64} : memref<4096xf32>
+
+  // Alignment is to element type boundaries (minimum 16 bytes).
+  // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]]
+  %2 = alloc() : memref<4096xvector<8xf32>>
+  // The minimum alignment is 16 bytes unless explicitly specified.
+  // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c16]],
+  %3 = alloc() : memref<4096xvector<2xf32>>
+  // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c8]],
+  %4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>>
+  // Bump the memref allocation size if its size is not a multiple of alignment.
+  // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT: llvm.urem
+  // ALIGNED-ALLOC-NEXT: llvm.sub
+  // ALIGNED-ALLOC-NEXT: llvm.urem
+  // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.add
+  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]])
+  %5 = alloc() {alignment = 32} : memref<100xf32>
+  // Bump alignment to the next power of two if it isn't.
+  // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : i64) : !llvm.i64
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[c128]]
+  %6 = alloc(%N) : memref<?xvector<18xf32>>
+  return %0 : memref<32x18xf32>
+}
+
 // CHECK-LABEL: func @mixed_load(
 // CHECK-COUNT-2: !llvm<"float*">,
 // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
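As a reading aid for the new ALIGNED-ALLOC checks above (illustrative, not part of the patch): an explicit `alignment` attribute (32, 64, 8 in the test) is passed straight to aligned_alloc, and otherwise the lowering uses max(16, element size rounded up to a power of two), which is where the 32, 16 and 128 constants come from. A scalar model of that default, with a made-up helper name:

  #include <cstdint>

  // Mirrors getAllocationAlignment for allocs without an alignment attribute:
  // max(16, PowerOf2Ceil(element size in bytes)).
  uint64_t defaultAllocAlignment(uint64_t eltSizeInBytes) {
    uint64_t pow2 = 1;
    while (pow2 < eltSizeInBytes)
      pow2 *= 2;
    return pow2 < 16 ? 16 : pow2;
  }
  // f32 (4 bytes) -> 16, vector<8xf32> (32 bytes) -> 32,
  // vector<2xf32> (8 bytes) -> 16, matching the checks above.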