diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -21,8 +21,7 @@
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
 void populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlloca);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
@@ -38,20 +37,16 @@
     bool emitCWrappers = false);

 /// Collect a set of default patterns to convert from the Standard dialect to
-/// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will
-/// generate `llvm.alloca` instead of calls to "malloc".
+/// LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool useAlloca = false,
                                          bool emitCWrappers = false);

 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
-/// arguments. If `useAlloca` is set, the patterns for AllocOp and DeallocOp
-/// will generate `llvm.alloca` instead of calls to "malloc".
+/// arguments.
 void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlloca = false);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
@@ -59,11 +54,8 @@

 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// By default stdlib malloc/free are used for allocating MemRef payloads.
-/// Specifying `useAlloca-true` emits stack allocations instead. In the future
-/// this may become an enum when we have concrete uses for other options.
 std::unique_ptr> createLowerToLLVMPass(
-    bool useAlloca = false, bool useBarePtrCallConv = false,
-    bool emitCWrappers = false,
+    bool useBarePtrCallConv = false, bool emitCWrappers = false,
     unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);

 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -124,6 +124,56 @@
     ArithmeticOp, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;

+// Base class for memref allocating ops: alloca and alloc.
+// +// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)> +// +class AllocLikeOp traits = []> : + Std_Op { + + let arguments = (ins Variadic:$value, + Confined, [IntMinValue<0>]>:$alignment); + let results = (outs AnyMemRef); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, MemRefType memrefType", [{ + result.types.push_back(memrefType); + }]>, + OpBuilder< + "Builder *builder, OperationState &result, MemRefType memrefType, " # + "ValueRange operands, IntegerAttr alignment = IntegerAttr()", [{ + result.addOperands(operands); + result.types.push_back(memrefType); + if (alignment) + result.addAttribute(getAlignmentAttrName(), alignment); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getAlignmentAttrName() { return "alignment"; } + + MemRefType getType() { return getResult().getType().cast(); } + + /// Returns the number of symbolic operands (the ones in square brackets), + /// which bind to the symbols of the memref's layout map. + unsigned getNumSymbolicOperands() { + return getNumOperands() - getType().getNumDynamicDims(); + } + + /// Returns the symbolic operands (the ones in square brackets), which bind + /// to the symbols of the memref's layout map. + operand_range getSymbolicOperands() { + return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; + } + + /// Returns the dynamic sizes for this alloc operation if specified. + operand_range getDynamicSizes() { return getOperands(); } + }]; + + let parser = [{ return ::parseAllocLikeOp(parser, result); }]; + + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // AbsFOp //===----------------------------------------------------------------------===// @@ -225,7 +275,7 @@ // AllocOp //===----------------------------------------------------------------------===// -def AllocOp : Std_Op<"alloc"> { +def AllocOp : AllocLikeOp<"alloc"> { let summary = "memory allocation operation"; let description = [{ The `alloc` operation allocates a region of memory, as specified by its @@ -234,7 +284,7 @@ Example: ```mlir - %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + %0 = alloc() : memref<8x64xf32, 1> ``` The optional list of dimension operands are bound to the dynamic dimensions @@ -242,7 +292,7 @@ bound to the second dimension of the memref (which is dynamic). ```mlir - %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1> + %0 = alloc(%d) : memref<8x?xf32, 1> ``` The optional list of symbol operands are bound to the symbols of the @@ -250,7 +300,8 @@ the symbol 's0' in the affine map specified in the allocs memref type. 
    ```mlir
-    %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+    %0 = alloc()[%s] : memref<8x64xf32,
+                              affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
     ```

     This operation returns a single ssa value of memref type, which can be used
@@ -262,49 +313,49 @@
     ```mlir
     %0 = alloc()[%s] {alignment = 8} :
-      memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+      memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
     ```
   }];
+}

-  let arguments = (ins Variadic:$value,
-                   Confined, [IntMinValue<0>]>:$alignment);
-  let results = (outs AnyMemRef);
+//===----------------------------------------------------------------------===//
+// AllocaOp
+//===----------------------------------------------------------------------===//

-  let builders = [OpBuilder<
-    "Builder *builder, OperationState &result, MemRefType memrefType", [{
-     result.types.push_back(memrefType);
-   }]>,
-   OpBuilder<
-    "Builder *builder, OperationState &result, MemRefType memrefType, " #
-    "ArrayRef operands, IntegerAttr alignment = IntegerAttr()", [{
-     result.addOperands(operands);
-     result.types.push_back(memrefType);
-     if (alignment)
-       result.addAttribute(getAlignmentAttrName(), alignment);
-   }]>];
+def AllocaOp : AllocLikeOp<"alloca"> {
+  let summary = "stack memory allocation operation";
+  let description = [{
+    The "alloca" operation allocates memory on the stack, to be automatically
+    released when the stack frame is discarded. The amount of memory allocated
+    is specified by its memref and additional operands. For example:

-  let extraClassDeclaration = [{
-    static StringRef getAlignmentAttrName() { return "alignment"; }
+    ```mlir
+    %0 = alloca() : memref<8x64xf32>
+    ```

-    MemRefType getType() { return getResult().getType().cast(); }
+    The optional list of dimension operands are bound to the dynamic dimensions
+    specified in its memref type. In the example below, the SSA value '%d' is
+    bound to the second dimension of the memref (which is dynamic).

-    /// Returns the number of symbolic operands (the ones in square brackets),
-    /// which bind to the symbols of the memref's layout map.
-    unsigned getNumSymbolicOperands() {
-      return getNumOperands() - getType().getNumDynamicDims();
-    }
+    ```mlir
+    %0 = alloca(%d) : memref<8x?xf32>
+    ```

-    /// Returns the symbolic operands (the ones in square brackets), which bind
-    /// to the symbols of the memref's layout map.
-    operand_range getSymbolicOperands() {
-      return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
-    }
+    The optional list of symbol operands are bound to the symbols of the
+    memref's affine map. In the example below, the SSA value '%s' is bound to
+    the symbol 's0' in the affine map specified in the alloca's memref type.

-    /// Returns the dynamic sizes for this alloc operation if specified.
-    operand_range getDynamicSizes() { return getOperands(); }
-  }];
+    ```mlir
+    %0 = alloca()[%s] : memref<8x64xf32,
+                               affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>>
+    ```

-  let hasCanonicalizer = 1;
+    This operation returns a single SSA value of memref type, which can be used
+    by subsequent load and store operations. An optional alignment attribute, if
+    specified, guarantees alignment at least to that boundary. If not specified,
+    an alignment on any convenient boundary compatible with the type will be
+    chosen.
+ }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/ADT/TypeSwitch.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -1238,32 +1238,27 @@ [](AffineMap map) { return map.isIdentity(); }); } -// An `alloc` is converted into a definition of a memref descriptor value and -// a call to `malloc` to allocate the underlying data buffer. The memref -// descriptor is of the LLVM structure type where: -// 1. the first element is a pointer to the allocated (typed) data buffer, -// 2. the second element is a pointer to the (typed) payload, aligned to the -// specified alignment, -// 3. the remaining elements serve to store all the sizes and strides of the -// memref using LLVM-converted `index` type. -// -// Alignment is obtained by allocating `alignment - 1` more bytes than requested -// and shifting the aligned pointer relative to the allocated memory. If -// alignment is unspecified, the two pointers are equal. -struct AllocOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +/// Lowering for AllocOp or AllocaOp. +template +struct AllocLikeOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Base = AllocLikeOpLowering; + using ConvertOpToLLVMPattern::createIndexConstant; + using ConvertOpToLLVMPattern::getIndexType; + using ConvertOpToLLVMPattern::typeConverter; + using ConvertOpToLLVMPattern::getVoidPtrType; - explicit AllocOpLowering(LLVMTypeConverter &converter, bool useAlloca = false) - : ConvertOpToLLVMPattern(converter), useAlloca(useAlloca) {} + explicit AllocLikeOpLowering(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} LogicalResult match(Operation *op) const override { - MemRefType type = cast(op).getType(); - if (isSupportedMemRefType(type)) + MemRefType memRefType = cast(op).getType(); + if (isSupportedMemRefType(memRefType)) return success(); int64_t offset; SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); + auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (failed(successStrides)) return failure(); @@ -1276,138 +1271,48 @@ return success(); } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto allocOp = cast(op); - MemRefType type = allocOp.getType(); - - // Get actual sizes of the memref as values: static sizes are constant - // values and dynamic sizes are passed to 'alloc' as operands. In case of - // zero-dimensional memref, assume a scalar (size 1). - SmallVector sizes; - sizes.reserve(type.getRank()); - unsigned i = 0; - for (int64_t s : type.getShape()) - sizes.push_back(s == -1 ? operands[i++] - : createIndexConstant(rewriter, loc, s)); - if (sizes.empty()) - sizes.push_back(createIndexConstant(rewriter, loc, 1)); - - // Compute the total number of memref elements. 
- Value cumulativeSize = sizes.front(); - for (unsigned i = 1, e = sizes.size(); i < e; ++i) - cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); - - // Compute the size of an individual element. This emits the MLIR equivalent - // of the following sizeof(...) implementation in LLVM IR: - // %0 = getelementptr %elementType* null, %indexType 1 - // %1 = ptrtoint %elementType* %0 to %indexType - // which is a common pattern of getting the size of a type in bytes. - auto elementType = type.getElementType(); - auto convertedPtrType = typeConverter.convertType(elementType) - .cast() - .getPointerTo(); - auto nullPtr = rewriter.create(loc, convertedPtrType); - auto one = createIndexConstant(rewriter, loc, 1); - auto gep = rewriter.create(loc, convertedPtrType, - ArrayRef{nullPtr, one}); - auto elementSize = - rewriter.create(loc, getIndexType(), gep); - cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); - - // Allocate the underlying buffer and store a pointer to it in the MemRef - // descriptor. - Value allocated = nullptr; - int alignment = 0; - Value alignmentValue = nullptr; - if (auto alignAttr = allocOp.alignment()) - alignment = alignAttr.getValue().getSExtValue(); - - if (useAlloca) { - allocated = rewriter.create(loc, getVoidPtrType(), - cumulativeSize, alignment); - } else { - // Insert the `malloc` declaration if it is not already present. - auto module = op->getParentOfType(); - auto mallocFunc = module.lookupSymbol("malloc"); - if (!mallocFunc) { - OpBuilder moduleBuilder( - op->getParentOfType().getBodyRegion()); - mallocFunc = moduleBuilder.create( - rewriter.getUnknownLoc(), "malloc", - LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(), - /*isVarArg=*/false)); - } - if (alignment != 0) { - alignmentValue = createIndexConstant(rewriter, loc, alignment); - cumulativeSize = rewriter.create( - loc, - rewriter.create(loc, cumulativeSize, alignmentValue), - one); - } - allocated = rewriter - .create( - loc, getVoidPtrType(), - rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize) - .getResult(0); - } - - auto structElementType = typeConverter.convertType(elementType); - auto elementPtrType = structElementType.cast().getPointerTo( - type.getMemorySpace()); - Value bitcastAllocated = rewriter.create( - loc, elementPtrType, ArrayRef(allocated)); - - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - assert(succeeded(successStrides) && "unexpected non-strided memref"); - (void)successStrides; - assert(offset != MemRefType::getDynamicStrideOrOffset() && - "unexpected dynamic offset"); - - // 0-D memref corner case: they have size 1 ... - assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) || - (strides.size() == sizes.size())) && - "unexpected number of strides"); - - // Create the MemRef descriptor. - auto structType = typeConverter.convertType(type); + /// Creates and populates the memref descriptor struct given all its fields. + /// This method also performs the any post allocation alignment needed for + /// heap allocations. 
+ MemRefDescriptor createMemRefDescriptor( + Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, + Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment, + uint64_t offset, ArrayRef strides, ArrayRef sizes) const { + auto elementPtrType = getElementPtrType(memRefType); + auto structType = typeConverter.convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + // Field 1: Allocated pointer, used for malloc/free. - memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); + memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedTypePtr); // Field 2: Actual aligned pointer to payload. - Value bitcastAligned = bitcastAllocated; - if (!useAlloca && alignment != 0) { - assert(alignmentValue); + Value alignedBytePtr = allocatedTypePtr; + if (accessAlignment) { // offset = (align - (ptr % align))% align Value intVal = rewriter.create( - loc, this->getIndexType(), allocated); + loc, this->getIndexType(), allocatedBytePtr); Value ptrModAlign = - rewriter.create(loc, intVal, alignmentValue); + rewriter.create(loc, intVal, accessAlignment); Value subbed = - rewriter.create(loc, alignmentValue, ptrModAlign); - Value offset = rewriter.create(loc, subbed, alignmentValue); - Value aligned = rewriter.create(loc, allocated.getType(), - allocated, offset); - bitcastAligned = rewriter.create( + rewriter.create(loc, accessAlignment, ptrModAlign); + Value offset = + rewriter.create(loc, subbed, accessAlignment); + Value aligned = rewriter.create( + loc, allocatedBytePtr.getType(), allocatedBytePtr, offset); + alignedBytePtr = rewriter.create( loc, elementPtrType, ArrayRef(aligned)); } - memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); + memRefDescriptor.setAlignedPtr(rewriter, loc, alignedBytePtr); // Field 3: Offset in aligned pointer. memRefDescriptor.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); - if (type.getRank() == 0) + if (memRefType.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. - return rewriter.replaceOp(op, {memRefDescriptor}); + return memRefDescriptor; - // Fields 4 and 5: Sizes and strides of the strided MemRef. + // Fields 4 and 5: sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. Value runningStride = nullptr; @@ -1433,12 +1338,184 @@ memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); } + return memRefDescriptor; + } + + /// Determines sizes to be used in the memref descriptor. + void getSizes(Location loc, MemRefType memRefType, ArrayRef operands, + ConversionPatternRewriter &rewriter, + SmallVectorImpl &sizes, Value &cumulativeSize, + Value &one) const { + sizes.reserve(memRefType.getRank()); + unsigned i = 0; + for (int64_t s : memRefType.getShape()) + sizes.push_back(s == -1 ? operands[i++] + : createIndexConstant(rewriter, loc, s)); + if (sizes.empty()) + sizes.push_back(createIndexConstant(rewriter, loc, 1)); + + // Compute the total number of memref elements. + cumulativeSize = sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSize = rewriter.create( + loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); + + // Compute the size of an individual element. This emits the MLIR equivalent + // of the following sizeof(...) 
implementation in LLVM IR:
+    //   %0 = getelementptr %elementType* null, %indexType 1
+    //   %1 = ptrtoint %elementType* %0 to %indexType
+    // which is a common pattern of getting the size of a type in bytes.
+    auto elementType = memRefType.getElementType();
+    auto convertedPtrType = typeConverter.convertType(elementType)
+                                .template cast()
+                                .getPointerTo();
+    auto nullPtr = rewriter.create(loc, convertedPtrType);
+    one = createIndexConstant(rewriter, loc, 1);
+    auto gep = rewriter.create(loc, convertedPtrType,
+                               ArrayRef{nullPtr, one});
+    auto elementSize =
+        rewriter.create(loc, getIndexType(), gep);
+    cumulativeSize = rewriter.create(
+        loc, getIndexType(), ArrayRef{cumulativeSize, elementSize});
+  }
+
+  /// Returns the type of a pointer to an element of the memref.
+  Type getElementPtrType(MemRefType memRefType) const {
+    auto elementType = memRefType.getElementType();
+    auto structElementType = typeConverter.convertType(elementType);
+    return structElementType.template cast().getPointerTo(
+        memRefType.getMemorySpace());
+  }
+
+  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
+  /// is set to null for stack allocations. `accessAlignment` is set if
+  /// alignment is needed post allocation (e.g., in conjunction with malloc).
+  /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
+  Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
+                       MemRefType memRefType, Value one,
+                       Value &accessAlignment, Value &allocatedBytePtr,
+                       ConversionPatternRewriter &rewriter) const {
+    auto elementPtrType = getElementPtrType(memRefType);
+
+    // Whether to use std lib function aligned_alloc that supports alignment.
+    Optional allocationAlignment = cast(op).alignment();
+
+    // With alloca, one gets a pointer to the element type right away.
+    bool onStack = isa(op);
+    if (onStack) {
+      allocatedBytePtr = nullptr;
+      accessAlignment = nullptr;
+      return rewriter.create(
+          loc, elementPtrType, cumulativeSize,
+          allocationAlignment ? allocationAlignment.getValue().getSExtValue()
+                              : 0);
+    }
+
+    // Use malloc. Insert the malloc declaration if it is not already present.
+    auto allocFuncName = "malloc";
+    AllocOp allocOp = cast(op);
+    auto module = allocOp.getParentOfType();
+    auto allocFunc = module.lookupSymbol(allocFuncName);
+    if (!allocFunc) {
+      OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion());
+      SmallVector callArgTypes = {getIndexType()};
+      allocFunc = moduleBuilder.create(
+          rewriter.getUnknownLoc(), allocFuncName,
+          LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
+                                        /*isVarArg=*/false));
+    }
+
+    // Allocate the underlying buffer and store a pointer to it in the MemRef
+    // descriptor.
+    SmallVector callArgs;
+    // Adjust the allocation size to consider alignment.
+    if (allocOp.alignment()) {
+      accessAlignment = createIndexConstant(
+          rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+      cumulativeSize = rewriter.create(
+          loc,
+          rewriter.create(loc, cumulativeSize, accessAlignment),
+          one);
+    }
+    callArgs.push_back(cumulativeSize);
+    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
+    allocatedBytePtr = rewriter
+                           .create(loc, getVoidPtrType(),
+                                   allocFuncSymbol, callArgs)
+                           .getResult(0);
+    // For heap allocations, the allocated pointer is a cast of the byte pointer
+    // to the type pointer.
+ return rewriter.create(loc, elementPtrType, + allocatedBytePtr); + } + + // An `alloc` is converted into a definition of a memref descriptor value and + // a call to `malloc` to allocate the underlying data buffer. The memref + // descriptor is of the LLVM structure type where: + // 1. the first element is a pointer to the allocated (typed) data buffer, + // 2. the second element is a pointer to the (typed) payload, aligned to the + // specified alignment, + // 3. the remaining elements serve to store all the sizes and strides of the + // memref using LLVM-converted `index` type. + // + // Alignment is performed by allocating `alignment - 1` more bytes than + // requested and shifting the aligned pointer relative to the allocated + // memory. If alignment is unspecified, the two pointers are equal. + + // An `alloca` is converted into a definition of a memref descriptor value and + // an llvm.alloca to allocate the underlying data buffer. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MemRefType memRefType = cast(op).getType(); + auto loc = op->getLoc(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + SmallVector sizes; + Value cumulativeSize, one; + getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one); + + // Allocate the underlying buffer. + // Value holding the alignment that has to be performed post allocation + // (in conjunction with allocators that do not support alignment, eg. + // malloc); nullptr if no such adjustment needs to be performed. + Value accessAlignment; + // Byte pointer to the allocated buffer. + Value allocatedBytePtr; + Value allocatedTypePtr = + allocateBuffer(loc, cumulativeSize, op, memRefType, one, + accessAlignment, allocatedBytePtr, rewriter); + + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(memRefType, strides, offset); + (void)successStrides; + assert(succeeded(successStrides) && "unexpected non-strided memref"); + assert(offset != MemRefType::getDynamicStrideOrOffset() && + "unexpected dynamic offset"); + + // 0-D memref corner case: they have size 1. + assert( + ((memRefType.getRank() == 0 && strides.empty() && sizes.size() == 1) || + (strides.size() == sizes.size())) && + "unexpected number of strides"); + + // Create the MemRef descriptor. + auto memRefDescriptor = createMemRefDescriptor( + loc, rewriter, memRefType, allocatedTypePtr, allocatedBytePtr, + accessAlignment, offset, strides, sizes); // Return the final value of the descriptor. 
rewriter.replaceOp(op, {memRefDescriptor}); } +}; - bool useAlloca; +struct AllocOpLowering : public AllocLikeOpLowering { + using Base::Base; +}; +struct AllocaOpLowering : public AllocLikeOpLowering { + using Base::Base; }; // A CallOp automatically promotes MemRefType to a sequence of alloca/store and @@ -1517,16 +1594,12 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - explicit DeallocOpLowering(LLVMTypeConverter &converter, - bool useAlloca = false) - : ConvertOpToLLVMPattern(converter), useAlloca(useAlloca) {} + explicit DeallocOpLowering(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (useAlloca) - return rewriter.eraseOp(op), success(); - assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); @@ -1549,8 +1622,6 @@ op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); return success(); } - - bool useAlloca; }; // A `rsqrt` is converted into `1 / sqrt`. @@ -2613,6 +2684,7 @@ AbsFOpLowering, AddFOpLowering, AddIOpLowering, + AllocaOpLowering, AndOpLowering, AtomicCmpXchgOpLowering, AtomicRMWOpLowering, @@ -2666,8 +2738,7 @@ } void mlir::populateStdToLLVMMemoryConversionPatters( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlloca) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, @@ -2679,7 +2750,7 @@ ViewOpLowering>(converter); patterns.insert< AllocOpLowering, - DeallocOpLowering>(converter, useAlloca); + DeallocOpLowering>(converter); // clang-format on } @@ -2691,11 +2762,11 @@ void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlloca, bool emitCWrappers) { + bool emitCWrappers) { populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns, emitCWrappers); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca); + populateStdToLLVMMemoryConversionPatters(converter, patterns); } static void populateStdToLLVMBarePtrFuncOpConversionPattern( @@ -2704,11 +2775,10 @@ } void mlir::populateStdToLLVMBarePtrConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlloca) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca); + populateStdToLLVMMemoryConversionPatters(converter, patterns); } // Create an LLVM IR structure type if there is more than one result. @@ -2782,9 +2852,8 @@ #include "mlir/Conversion/Passes.h.inc" /// Creates an LLVM lowering pass. 
- LLVMLoweringPass(bool useAlloca, bool useBarePtrCallConv, bool emitCWrappers, + LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth) { - this->useAlloca = useAlloca; this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; @@ -2812,10 +2881,9 @@ OwningRewritePatternList patterns; if (useBarePtrCallConv) - populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns, - useAlloca); + populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns); else - populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca, + populateStdToLLVMConversionPatterns(typeConverter, patterns, emitCWrappers); LLVMConversionTarget target(getContext()); @@ -2833,8 +2901,8 @@ } std::unique_ptr> -mlir::createLowerToLLVMPass(bool useAlloca, bool useBarePtrCallConv, - bool emitCWrappers, unsigned indexBitwidth) { - return std::make_unique(useAlloca, useBarePtrCallConv, - emitCWrappers, indexBitwidth); +mlir::createLowerToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers, + unsigned indexBitwidth) { + return std::make_unique(useBarePtrCallConv, emitCWrappers, + indexBitwidth); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -242,11 +242,14 @@ } //===----------------------------------------------------------------------===// -// AllocOp +// AllocOp / AllocaOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, AllocOp op) { - p << "alloc"; +template +static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) { + static_assert(llvm::is_one_of::value, + "applies to only alloc or alloca"); + p << name; // Print dynamic dimension operands. MemRefType type = op.getType(); @@ -256,7 +259,16 @@ p << " : " << type; } -static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) { +static void print(OpAsmPrinter &p, AllocOp op) { + printAllocLikeOp(p, op, "alloc"); +} + +static void print(OpAsmPrinter &p, AllocaOp op) { + printAllocLikeOp(p, op, "alloca"); +} + +static ParseResult parseAllocLikeOp(OpAsmParser &parser, + OperationState &result) { MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a @@ -281,8 +293,12 @@ return success(); } -static LogicalResult verify(AllocOp op) { - auto memRefType = op.getResult().getType().dyn_cast(); +template +static LogicalResult verify(AllocLikeOp op) { + static_assert(std::is_same::value || + std::is_same::value, + "applies to only alloc or alloca"); + auto memRefType = op.getResult().getType().template dyn_cast(); if (!memRefType) return op.emitOpError("result must be a memref"); @@ -309,11 +325,12 @@ } namespace { -/// Fold constant dimensions into an alloc operation. -struct SimplifyAllocConst : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Fold constant dimensions into an alloc like operation. +template +struct SimplifyAllocConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AllocOp alloc, + LogicalResult matchAndRewrite(AllocLikeOp alloc, PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. @@ -357,8 +374,8 @@ newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. 
- auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, - newOperands, IntegerAttr()); + auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, + newOperands, IntegerAttr()); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); @@ -386,7 +403,12 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert, SimplifyDeadAlloc>(context); +} + +void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -93,6 +93,42 @@ return %0 : memref } +// ----- + +// CHECK-LABEL: func @dynamic_alloca +// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +func @dynamic_alloca(%arg0: index, %arg1: index) -> memref { +// CHECK: %[[num_elems:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64 +// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + %0 = alloca(%arg0, %arg1) : memref + +// Test with explicitly specified alignment. llvm.alloca takes care of the +// alignment. The same pointer is thus used for allocation and aligned +// accesses. 
+// CHECK: %[[alloca_aligned:.*]] = llvm.alloca %{{.*}} x !llvm.float {alignment = 32 : i64} : (!llvm.i64) -> !llvm<"float*"> +// CHECK: %[[desc:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[desc1:.*]] = llvm.insertvalue %[[alloca_aligned]], %[[desc]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: llvm.insertvalue %[[alloca_aligned]], %[[desc1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + alloca(%arg0, %arg1) {alignment = 32} : memref + return %0 : memref +} + // CHECK-LABEL: func @dynamic_dealloc func @dynamic_dealloc(%arg0: memref) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -1,5 +1,4 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s -// RUN: mlir-opt -convert-std-to-llvm='use-alloca=1' %s | FileCheck %s --check-prefix=ALLOCA // RUN: mlir-opt -convert-std-to-llvm='use-bare-ptr-memref-call-conv=1' -split-input-file %s | FileCheck %s --check-prefix=BAREPTR // BAREPTR-LABEL: func @check_noalias @@ -67,7 +66,6 @@ // ----- // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { -// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { // BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { func @zero_d_alloc() -> memref { // CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 @@ -84,10 +82,6 @@ // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> -// ALLOCA-NOT: malloc -// ALLOCA: alloca -// ALLOCA-NOT: malloc - // BAREPTR-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> // BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 @@ -207,6 +201,32 @@ // ----- +// CHECK-LABEL: func @static_alloca() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +func @static_alloca() -> memref<32x18xf32> { +// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*"> + %0 = alloca() : memref<32x18xf32> + + // Test with explicitly specified alignment. llvm.alloca takes care of the + // alignment. The same pointer is thus used for allocation and aligned + // accesses. 
+ // CHECK: %[[alloca_aligned:.*]] = llvm.alloca %{{.*}} x !llvm.float {alignment = 32 : i64} : (!llvm.i64) -> !llvm<"float*"> + // CHECK: %[[desc:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[desc1:.*]] = llvm.insertvalue %[[alloca_aligned]], %[[desc]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[alloca_aligned]], %[[desc1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + alloca() {alignment = 32} : memref<32x18xf32> + return %0 : memref<32x18xf32> +} + +// ----- + // CHECK-LABEL: func @static_dealloc // BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) { func @static_dealloc(%static: memref<10x8xf32>) { diff --git a/mlir/test/IR/memory-ops.mlir b/mlir/test/IR/memory-ops.mlir --- a/mlir/test/IR/memory-ops.mlir +++ b/mlir/test/IR/memory-ops.mlir @@ -33,6 +33,35 @@ return } +// CHECK-LABEL: func @alloca() { +func @alloca() { +^bb0: + // Test simple alloc. + // CHECK: %0 = alloca() : memref<1024x64xf32, 1> + %0 = alloca() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + + %c0 = "std.constant"() {value = 0: index} : () -> index + %c1 = "std.constant"() {value = 1: index} : () -> index + + // Test alloca with dynamic dimensions. + // CHECK: %1 = alloca(%c0, %c1) : memref + %1 = alloca(%c0, %c1) : memref (d0, d1)>, 1> + + // Test alloca with no dynamic dimensions and one symbol. + // CHECK: %2 = alloca()[%c0] : memref<2x4xf32, #map0, 1> + %2 = alloca()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + + // Test alloca with dynamic dimensions and one symbol. + // CHECK: %3 = alloca(%c1)[%c0] : memref<2x?xf32, #map0, 1> + %3 = alloca(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> + + // Alloca with no mappings, but with alignment. + // CHECK: %4 = alloca() {alignment = 64 : i64} : memref<2xi32> + %4 = alloca() {alignment = 64} : memref<2 x i32> + + return +} + // CHECK-LABEL: func @dealloc() { func @dealloc() { ^bb0: diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -392,15 +392,21 @@ %N = constant 1024 : index %K = constant 512 : index - // CHECK-NEXT: %0 = alloc(%arg0) : memref + // CHECK-NEXT: alloc(%arg0) : memref %a = alloc(%L, %N) : memref - // CHECK-NEXT: %1 = alloc(%arg1) : memref<4x1024x8x512x?xf32> + // CHECK-NEXT: alloc(%arg1) : memref<4x1024x8x512x?xf32> %b = alloc(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32> - // CHECK-NEXT: %2 = alloc() : memref<512x1024xi32> + // CHECK-NEXT: alloc() : memref<512x1024xi32> %c = alloc(%K, %N) : memref + // CHECK: alloc() : memref<9x9xf32> + %d = alloc(%nine, %nine) : memref + + // CHECK: alloca(%arg1) : memref<4x1024x8x512x?xf32> + %e = alloca(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32> + // CHECK: affine.for affine.for %i = 0 to %L { // CHECK-NEXT: affine.for @@ -412,9 +418,6 @@ } } - // CHECK: alloc() : memref<9x9xf32> - %d = alloc(%nine, %nine) : memref - return %c, %d : memref, memref }
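For reference, a minimal sketch of what the now-shared `SimplifyAllocConst` pattern is expected to do to an `alloca` whose dynamic dimension is fed by a constant, as exercised by the canonicalization test above. This assumes the `-canonicalize` pass and the Standard-dialect `memref_cast` op at this revision; SSA names and the exact printed output are illustrative:

```mlir
// Before canonicalization: a dynamic dimension bound to a known constant.
%c8 = constant 8 : index
%0  = alloca(%c8) : memref<?x16xf32>

// After canonicalization: the constant is folded into the memref type, and a
// cast reconciles any uses that still expect the original (dynamic) type.
%1 = alloca() : memref<8x16xf32>
%2 = memref_cast %1 : memref<8x16xf32> to memref<?x16xf32>
```

The same folding already applies to `alloc`; the only difference in this patch is that the dead-allocation cleanup (`SimplifyDeadAlloc`) is registered only for `alloc`, not for `alloca`.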