diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1970,11 +1970,15 @@
     return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
   }
 
+  /// Returns true if the buffer allocation needs the buffer size to be
+  /// computed. This size feeds into the `bufferSize` argument of
+  /// `allocateBuffer`.
+  virtual bool needsBufferSize() const { return true; }
+
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
   virtual std::tuple<Value, Value>
   allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value cumulativeSize, Operation *op) const = 0;
+                 Value bufferSize, Operation *op) const = 0;
 
 private:
   static MemRefType getMemRefResultType(Operation *op) {
@@ -2027,14 +2031,16 @@
     SmallVector<Value, 4> sizes;
     this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes);
 
-    Value cumulativeSize = this->getCumulativeSizeInBytes(
-        loc, memRefType.getElementType(), sizes, rewriter);
+    Value bufferSize;
+    if (needsBufferSize())
+      bufferSize = this->getCumulativeSizeInBytes(
+          loc, memRefType.getElementType(), sizes, rewriter);
 
     // Allocate the underlying buffer.
     Value allocatedPtr;
     Value alignedPtr;
     std::tie(allocatedPtr, alignedPtr) =
-        this->allocateBuffer(rewriter, loc, cumulativeSize, op);
+        this->allocateBuffer(rewriter, loc, bufferSize, op);
 
     int64_t offset;
     SmallVector<int64_t, 4> strides;
@@ -2065,7 +2071,7 @@
       : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
     // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
@@ -2084,15 +2090,14 @@
 
     if (alignment) {
       // Adjust the allocation size to consider alignment.
-      cumulativeSize =
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignment);
+      bufferSize = rewriter.create<LLVM::AddOp>(loc, bufferSize, alignment);
     }
 
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
     Value allocatedPtr =
-        createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize},
+        createAllocCall(loc, "malloc", elementPtrType, {bufferSize},
                         allocOp.getParentOfType<ModuleOp>(), rewriter);
 
     Value alignedPtr = allocatedPtr;
@@ -2159,7 +2164,7 @@
   }
 
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
     // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
@@ -2170,12 +2175,11 @@
     // aligned_alloc requires size to be a multiple of alignment; we will pad
     // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment))
-      cumulativeSize =
-          createAligned(rewriter, loc, cumulativeSize, allocAlignment);
+      bufferSize = createAligned(rewriter, loc, bufferSize, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
     Value allocatedPtr = createAllocCall(
-        loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize},
+        loc, "aligned_alloc", elementPtrType, {allocAlignment, bufferSize},
         allocOp.getParentOfType<ModuleOp>(), rewriter);
 
     return std::make_tuple(allocatedPtr, allocatedPtr);
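
For intuition, the `createAligned` call above rounds the buffer size up to the
next multiple of the allocation alignment, as aligned_alloc requires. A minimal
sketch of that rounding in LLVM dialect IR, assuming `%size` and `%align` are
`!llvm.i64` values (the helper's actual op sequence may differ):

    // padded = bumped - (bumped urem align), with bumped = size + align - 1.
    %one    = llvm.mlir.constant(1 : index) : !llvm.i64
    %bump   = llvm.sub %align, %one : !llvm.i64
    %bumped = llvm.add %size, %bump : !llvm.i64
    %rem    = llvm.urem %bumped, %align : !llvm.i64
    %padded = llvm.sub %bumped, %rem : !llvm.i64
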
@@ -2196,7 +2200,7 @@
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is needed post allocation (for eg. in conjunction with malloc).
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value cumulativeSize,
+                                          Location loc, Value bufferSize,
                                           Operation *op) const override {
 
     // With alloca, one gets a pointer to the element type right away.
@@ -2205,7 +2209,7 @@
     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
 
     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, elementPtrType, cumulativeSize,
+        loc, elementPtrType, bufferSize,
         allocaOp.alignment() ? *allocaOp.alignment() : 0);
 
     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
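
The next hunk adds the two new lowerings. For context, the ops being lowered
look like this at the std level (an illustrative sketch with hypothetical
symbol names; the exact forms exercised are in the tests further below):

    // External global: declared here, defined elsewhere.
    global_memref @ext : memref<4xf32>

    // Private, initialized, read-only global.
    global_memref "private" constant @lut : memref<2xf32> = dense<[0.0, 1.0]>

    func @use() -> memref<2xf32> {
      // Materializes a memref descriptor pointing at @lut's storage.
      %0 = get_global_memref @lut : memref<2xf32>
      return %0 : memref<2xf32>
    }
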
@@ -2420,6 +2424,100 @@
   }
 };
 
+/// Returns the LLVM type of the global variable given the memref type `type`.
+static LLVM::LLVMType
+convertGlobalMemrefTypeToLLVM(MemRefType type,
+                              LLVMTypeConverter &typeConverter) {
+  // The LLVM type for a global memref will be a multi-dimensional array. For
+  // declarations or uninitialized global memrefs, we could potentially
+  // flatten this to a 1D array. However, for global_memref's with an initial
+  // value, we do not intend to flatten the ElementsAttribute when going from
+  // std -> LLVM dialect, so the LLVM type needs to be a multi-dimensional
+  // array.
+  LLVM::LLVMType elementType =
+      unwrap(typeConverter.convertType(type.getElementType()));
+  LLVM::LLVMType arrayTy = elementType;
+  // The shape has the outermost dimension at index 0, so walk it backwards.
+  for (int64_t dim : llvm::reverse(type.getShape()))
+    arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
+  return arrayTy;
+}
+
+/// GlobalMemrefOp is lowered to an LLVM global variable.
+struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
+  using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto global = cast<GlobalMemrefOp>(op);
+    MemRefType type = global.type().cast<MemRefType>();
+    if (!isSupportedMemRefType(type))
+      return failure();
+
+    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+
+    LLVM::Linkage linkage =
+        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
+
+    Attribute initialValue = nullptr;
+    if (!global.isExternal() && !global.isUninitialized()) {
+      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
+      initialValue = elementsAttr;
+
+      // For scalar memrefs, the global variable created is of the element
+      // type, so unpack the elements attribute to extract the value.
+      if (type.getRank() == 0)
+        initialValue = elementsAttr.getValue({});
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
+        op, arrayTy, global.constant(), linkage, global.sym_name(),
+        initialValue, type.getMemorySpace());
+    return success();
+  }
+};
+
+/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
+/// the first element stashed into it. This derives from `AllocLikeOpLowering`
+/// to reuse the memref descriptor construction.
+struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
+  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
+      : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(),
+                            converter) {}
+
+  /// "Allocation" for GetGlobalMemrefOp just returns the address of the
+  /// global variable, so there is no need to compute a buffer size.
+  bool needsBufferSize() const override { return false; }
+
+  /// Buffer "allocation" for the get_global_memref op is getting the address
+  /// of the referenced global variable.
+  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+                                          Location loc, Value bufferSize,
+                                          Operation *op) const override {
+    auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
+    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
+    unsigned memSpace = type.getMemorySpace();
+
+    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
+        loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+
+    // Get the address of the first element in the array by creating a GEP
+    // with the address of the global variable as the base, and (rank + 1)
+    // number of 0 indices.
+    LLVM::LLVMType elementType =
+        unwrap(typeConverter.convertType(type.getElementType()));
+    LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+
+    SmallVector<Value, 4> operands = {addressOf};
+    operands.insert(operands.end(), type.getRank() + 1,
+                    createIndexConstant(rewriter, loc, 0));
+    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+
+    // Both the allocated and aligned pointers are the same. We could
+    // potentially stash a nullptr for the allocated pointer since we do not
+    // expect any dealloc.
+    return {gep, gep};
+  }
+};
+
 // A `rsqrt` is converted into `1 / sqrt`.
 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -3941,6 +4039,8 @@
       AssumeAlignmentOpLowering,
       DeallocOpLowering,
      DimOpLowering,
+      GlobalMemrefOpLowering,
+      GetGlobalMemrefOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefReinterpretCastOpLowering,
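
Pieced together, the two patterns turn a 1D global and its use into roughly
the following LLVM dialect IR (an illustrative sketch; the precise descriptor
construction is pinned down by the CHECK lines in the test diff below):

    llvm.mlir.global external @gv0() : !llvm.array<2 x float>

    // `%0 = get_global_memref @gv0 : memref<2xf32>` becomes:
    %addr = llvm.mlir.addressof @gv0 : !llvm.ptr<array<2 x float>>
    %c0 = llvm.mlir.constant(0 : index) : !llvm.i64
    %gep = llvm.getelementptr %addr[%c0, %c0]
        : (!llvm.ptr<array<2 x float>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
    // %gep then serves as both the allocated and the aligned pointer of the
    // memref descriptor, with offset 0, size 2, and stride 1.
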
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -131,3 +131,76 @@
   %0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
   return
 }
+
+// -----
+
+// CHECK: llvm.mlir.global external @gv0() : !llvm.array<2 x float>
+global_memref @gv0 : memref<2xf32> = uninitialized
+
+// CHECK: llvm.mlir.global private @gv1() : !llvm.array<2 x float>
+global_memref "private" @gv1 : memref<2xf32>
+
+// CHECK: llvm.mlir.global external @gv2(dense<{{\[\[}}0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]]> : tensor<2x3xf32>) : !llvm.array<2 x array<3 x float>>
+global_memref @gv2 : memref<2x3xf32> = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]>
+
+// Test 1D memref.
+// CHECK-LABEL: func @get_gv0_memref
+func @get_gv0_memref() {
+  %0 = get_global_memref @gv0 : memref<2xf32>
+  // CHECK: %[[DIM:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv0 : !llvm.ptr<array<2 x float>>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x float>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DIM]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+  return
+}
+
+// Test 2D memref.
+// CHECK-LABEL: func @get_gv2_memref
+func @get_gv2_memref() {
+  // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+  // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv2 : !llvm.ptr<array<2 x array<3 x float>>>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]], %[[ZERO]], %[[ZERO]]] : (!llvm.ptr<array<2 x array<3 x float>>>, !llvm.i64, !llvm.i64, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DIM0]], {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE0]], {{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[DIM1]], {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.insertvalue %[[STRIDE1]], {{.*}}[4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+
+  %0 = get_global_memref @gv2 : memref<2x3xf32>
+  return
+}
+
+// Test scalar memref.
+// CHECK: llvm.mlir.global external @gv3(1.000000e+00 : f32) : !llvm.float
+global_memref @gv3 : memref<f32> = dense<1.0>
+
+// CHECK-LABEL: func @get_gv3_memref
+func @get_gv3_memref() {
+  // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @gv3 : !llvm.ptr<float>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDR]][%[[ZERO]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: llvm.insertvalue %[[GEP]], {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
+  %0 = get_global_memref @gv3 : memref<f32>
+  return
+}
diff --git a/mlir/test/mlir-cpu-runner/global_memref.mlir b/mlir/test/mlir-cpu-runner/global_memref.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/global_memref.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+func @print_memref_i32(memref<*xi32>) attributes { llvm.emit_c_interface }
+func @printNewline() -> ()
+
+global_memref "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func @test1DMemref() {
+  %0 = get_global_memref @gv0 : memref<4xf32>
+  %U = memref_cast %0 : memref<4xf32> to memref<*xf32>
+  // CHECK: rank = 1
+  // CHECK: offset = 0
+  // CHECK: sizes = [4]
+  // CHECK: strides = [1]
+  // CHECK: [0, 1, 2, 3]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+
+  // Overwrite some of the elements.
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %fp0 = constant 4.0 : f32
+  %fp1 = constant 5.0 : f32
+  store %fp0, %0[%c0] : memref<4xf32>
+  store %fp1, %0[%c2] : memref<4xf32>
+  // CHECK: rank = 1
+  // CHECK: offset = 0
+  // CHECK: sizes = [4]
+  // CHECK: strides = [1]
+  // CHECK: [4, 1, 5, 3]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref constant @gv1 : memref<3x2xi32> = dense<[[0, 1],[2, 3],[4, 5]]>
+func @testConstantMemref() {
+  %0 = get_global_memref @gv1 : memref<3x2xi32>
+  %U = memref_cast %0 : memref<3x2xi32> to memref<*xi32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [3, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0, 1]
+  // CHECK: [2, 3]
+  // CHECK: [4, 5]
+  call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref "private" @gv2 : memref<4x2xf32> = dense<[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]>
+func @test2DMemref() {
+  %0 = get_global_memref @gv2 : memref<4x2xf32>
+  %U = memref_cast %0 : memref<4x2xf32> to memref<*xf32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [4, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0, 1]
+  // CHECK: [2, 3]
+  // CHECK: [4, 5]
+  // CHECK: [6, 7]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+
+  // Overwrite the 1.0 (at index [0, 1]) with 10.0.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %fp10 = constant 10.0 : f32
+  store %fp10, %0[%c0, %c1] : memref<4x2xf32>
+  // CHECK: rank = 2
+  // CHECK: offset = 0
+  // CHECK: sizes = [4, 2]
+  // CHECK: strides = [2, 1]
+  // CHECK: [0, 10]
+  // CHECK: [2, 3]
+  // CHECK: [4, 5]
+  // CHECK: [6, 7]
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+global_memref @gv3 : memref<i32> = dense<11>
+func @testScalarMemref() {
+  %0 = get_global_memref @gv3 : memref<i32>
+  %U = memref_cast %0 : memref<i32> to memref<*xi32>
+  // CHECK: rank = 0
+  // CHECK: offset = 0
+  // CHECK: sizes = []
+  // CHECK: strides = []
+  // CHECK: [11]
+  call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+  call @printNewline() : () -> ()
+  return
+}
+
+func @main() -> () {
+  call @test1DMemref() : () -> ()
+  call @testConstantMemref() : () -> ()
+  call @test2DMemref() : () -> ()
+  call @testScalarMemref() : () -> ()
+  return
+}