diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -606,6 +606,59 @@
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
+/// Lowering for AllocOp and AllocaOp.
+struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
+  using ConvertToLLVMPattern::createIndexConstant;
+  using ConvertToLLVMPattern::getIndexType;
+  using ConvertToLLVMPattern::getVoidPtrType;
+
+  explicit AllocLikeOpLLVMLowering(StringRef opName,
+                                   LLVMTypeConverter &converter)
+      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+
+protected:
+  /// Returns 'input' aligned up to 'alignment'. Computes
+  ///   bumped = input + alignment - 1
+  ///   aligned = bumped - bumped % alignment
+  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                             Value input, Value alignment);
+
+  /// Allocates the underlying buffer. Returns the allocated pointer and the
+  /// aligned pointer.
+  virtual std::tuple<Value, Value>
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
+                 Value sizeBytes, Operation *op) const = 0;
+
+private:
+  static MemRefType getMemRefResultType(Operation *op) {
+    return op->getResult(0).getType().cast<MemRefType>();
+  }
+
+  LogicalResult match(Operation *op) const override {
+    MemRefType memRefType = getMemRefResultType(op);
+    return success(isConvertibleAndHasIdentityMaps(memRefType));
+  }
+
+  // An `alloc` is converted into a definition of a memref descriptor value and
+  // a call to `malloc` to allocate the underlying data buffer. The memref
+  // descriptor is of the LLVM structure type where:
+  //   1. the first element is a pointer to the allocated (typed) data buffer,
+  //   2. the second element is a pointer to the (typed) payload, aligned to the
+  //      specified alignment,
+  //   3. the remaining elements serve to store all the sizes and strides of the
+  //      memref using LLVM-converted `index` type.
+  //
+  // Alignment is performed by allocating `alignment` more bytes than
+  // requested and shifting the aligned pointer relative to the allocated
+  // memory. Note: `alignment - <minimum element size>` would actually be
+  // sufficient. If alignment is unspecified, the two pointers are equal.
+
+  // An `alloca` is converted into a definition of a memref descriptor value and
+  // an llvm.alloca to allocate the underlying data buffer.
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const override;
+};
+
 namespace LLVM {
 namespace detail {
 /// Replaces the given operation "op" with a new operation of type "targetOp"
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1831,92 +1831,10 @@
   }
 };
 
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLowering : public ConvertToLLVMPattern {
-  using ConvertToLLVMPattern::createIndexConstant;
-  using ConvertToLLVMPattern::getIndexType;
-  using ConvertToLLVMPattern::getVoidPtrType;
-
-  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
-
-protected:
-  // Returns 'input' aligned up to 'alignment'. Computes
-  //   bumped = input + alignement - 1
-  //   aligned = bumped - bumped % alignment
-  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
-                             Value input, Value alignment) {
-    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
-    Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
-    Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
-    Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
-    return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-  }
-
-  /// Allocates the underlying buffer. Returns the allocated pointer and the
-  /// aligned pointer.
-  virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value sizeBytes, Operation *op) const = 0;
-
-private:
-  static MemRefType getMemRefResultType(Operation *op) {
-    return op->getResult(0).getType().cast<MemRefType>();
-  }
-
-  LogicalResult match(Operation *op) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    return success(isConvertibleAndHasIdentityMaps(memRefType));
-  }
-
-  // An `alloc` is converted into a definition of a memref descriptor value and
-  // a call to `malloc` to allocate the underlying data buffer. The memref
-  // descriptor is of the LLVM structure type where:
-  //   1. the first element is a pointer to the allocated (typed) data buffer,
-  //   2. the second element is a pointer to the (typed) payload, aligned to the
-  //      specified alignment,
-  //   3. the remaining elements serve to store all the sizes and strides of the
-  //      memref using LLVM-converted `index` type.
-  //
-  // Alignment is performed by allocating `alignment` more bytes than
-  // requested and shifting the aligned pointer relative to the allocated
-  // memory. Note: `alignment - <minimum element size>` would actually be
-  // sufficient. If alignment is unspecified, the two pointers are equal.
-
-  // An `alloca` is converted into a definition of a memref descriptor value and
-  // an llvm.alloca to allocate the underlying data buffer.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    auto loc = op->getLoc();
-
-    // Get actual sizes of the memref as values: static sizes are constant
-    // values and dynamic sizes are passed to 'alloc' as operands. In case of
-    // zero-dimensional memref, assume a scalar (size 1).
-    SmallVector<Value, 4> sizes;
-    SmallVector<Value, 4> strides;
-    Value sizeBytes;
-    this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                   strides, sizeBytes);
-
-    // Allocate the underlying buffer.
-    Value allocatedPtr;
-    Value alignedPtr;
-    std::tie(allocatedPtr, alignedPtr) =
-        this->allocateBuffer(rewriter, loc, sizeBytes, op);
-
-    // Create the MemRef descriptor.
-    auto memRefDescriptor = this->createMemRefDescriptor(
-        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
-    // Return the final value of the descriptor.
-    rewriter.replaceOp(op, {memRefDescriptor});
-  }
-};
-
-struct AllocOpLowering : public AllocLikeOpLowering {
+struct AllocOpLowering : public AllocLikeOpLLVMLowering {
   AllocOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
+                                converter) {}
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
@@ -1967,9 +1885,10 @@
   }
 };
 
-struct AlignedAllocOpLowering : public AllocLikeOpLowering {
+struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
   AlignedAllocOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
+                                converter) {}
 
   /// Returns the memref's element size in bytes.
   // TODO: there are other places where this is used. Expose publicly?
@@ -2047,9 +1966,10 @@
 // Out of line definition, required till C++17.
 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
 
-struct AllocaOpLowering : public AllocLikeOpLowering {
+struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
   AllocaOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::AllocaOp::getOperationName(), converter) {}
+      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
+                                converter) {}
 
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
@@ -2310,10 +2230,10 @@
 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
 /// the first element stashed into the descriptor. This reuses
-/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
-struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
+/// `AllocLikeOpLLVMLowering` to reuse the Memref descriptor construction.
+struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
-      : AllocLikeOpLowering(memref::GetGlobalOp::getOperationName(),
-                            converter) {}
+      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
+                                converter) {}
 
   /// Buffer "allocation" for memref.get_global op is getting the address of
   /// the global variable referenced.
@@ -4195,6 +4115,45 @@
 };
 } // end namespace
 
+Value AllocLikeOpLLVMLowering::createAligned(
+    ConversionPatternRewriter &rewriter, Location loc, Value input,
+    Value alignment) {
+  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
+  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+void AllocLikeOpLLVMLowering::rewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  MemRefType memRefType = getMemRefResultType(op);
+  auto loc = op->getLoc();
+
+  // Get actual sizes of the memref as values: static sizes are constant
+  // values and dynamic sizes are passed to 'alloc' as operands. In case of
+  // zero-dimensional memref, assume a scalar (size 1).
+  SmallVector<Value, 4> sizes;
+  SmallVector<Value, 4> strides;
+  Value sizeBytes;
+  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
+                                 strides, sizeBytes);
+
+  // Allocate the underlying buffer.
+  Value allocatedPtr;
+  Value alignedPtr;
+  std::tie(allocatedPtr, alignedPtr) =
+      this->allocateBuffer(rewriter, loc, sizeBytes, op);
+
+  // Create the MemRef descriptor.
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
+
+  // Return the final value of the descriptor.
+  rewriter.replaceOp(op, {memRefDescriptor});
+}
+
 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
     : ConversionTarget(ctx) {
   this->addLegalDialect<LLVM::LLVMDialect>();