diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -440,6 +440,20 @@ ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const; + /// Returns the type of a pointer to an element of the memref. + Type getElementPtrType(MemRefType type) const; + + /// Determines sizes to be used in the memref descriptor. + void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, + ArrayRef dynSizes, + ConversionPatternRewriter &rewriter, + SmallVectorImpl &sizes) const; + + /// Computes the total size in bytes needed to store the given shape. + Value getCumulativeSizeInBytes(Location loc, Type elementType, + ArrayRef shape, + ConversionPatternRewriter &rewriter) const; + protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -925,6 +925,51 @@ offset, rewriter); } +Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { + auto elementType = type.getElementType(); + auto structElementType = typeConverter.convertType(elementType); + return structElementType.cast().getPointerTo( + type.getMemorySpace()); +} + +void ConvertToLLVMPattern::getMemRefDescriptorSizes( + Location loc, MemRefType memRefType, ArrayRef dynSizes, + ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes) const { + sizes.reserve(memRefType.getRank()); + unsigned i = 0; + for (int64_t s : memRefType.getShape()) + sizes.push_back(s == -1 ? 
dynSizes[i++] + : createIndexConstant(rewriter, loc, s)); +} + +Value ConvertToLLVMPattern::getCumulativeSizeInBytes( + Location loc, Type elementType, ArrayRef sizes, + ConversionPatternRewriter &rewriter) const { + // Compute the total number of memref elements. + Value cumulativeSizeInBytes = + sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSizeInBytes = rewriter.create( + loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, sizes[i]}); + + // Compute the size of an individual element. This emits the MLIR equivalent + // of the following sizeof(...) implementation in LLVM IR: + // %0 = getelementptr %elementType* null, %indexType 1 + // %1 = ptrtoint %elementType* %0 to %indexType + // which is a common pattern of getting the size of a type in bytes. + auto convertedPtrType = typeConverter.convertType(elementType) + .template cast() + .getPointerTo(); + auto nullPtr = rewriter.create(loc, convertedPtrType); + auto gep = rewriter.create( + loc, convertedPtrType, + ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); + auto elementSize = + rewriter.create(loc, getIndexType(), gep); + return rewriter.create( + loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, elementSize}); +} + /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. 
@@ -1698,7 +1743,7 @@ Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment, uint64_t offset, ArrayRef strides, ArrayRef sizes) const { - auto elementPtrType = getElementPtrType(memRefType); + auto elementPtrType = this->getElementPtrType(memRefType); auto structType = typeConverter.convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); @@ -1756,52 +1801,6 @@ return memRefDescriptor; } - /// Determines sizes to be used in the memref descriptor. - void getSizes(Location loc, MemRefType memRefType, ArrayRef operands, - ConversionPatternRewriter &rewriter, - SmallVectorImpl &sizes, Value &cumulativeSize, - Value &one) const { - sizes.reserve(memRefType.getRank()); - unsigned i = 0; - for (int64_t s : memRefType.getShape()) - sizes.push_back(s == -1 ? operands[i++] - : createIndexConstant(rewriter, loc, s)); - if (sizes.empty()) - sizes.push_back(createIndexConstant(rewriter, loc, 1)); - - // Compute the total number of memref elements. - cumulativeSize = sizes.front(); - for (unsigned i = 1, e = sizes.size(); i < e; ++i) - cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); - - // Compute the size of an individual element. This emits the MLIR equivalent - // of the following sizeof(...) implementation in LLVM IR: - // %0 = getelementptr %elementType* null, %indexType 1 - // %1 = ptrtoint %elementType* %0 to %indexType - // which is a common pattern of getting the size of a type in bytes. 
- auto elementType = memRefType.getElementType(); - auto convertedPtrType = typeConverter.convertType(elementType) - .template cast() - .getPointerTo(); - auto nullPtr = rewriter.create(loc, convertedPtrType); - one = createIndexConstant(rewriter, loc, 1); - auto gep = rewriter.create(loc, convertedPtrType, - ArrayRef{nullPtr, one}); - auto elementSize = - rewriter.create(loc, getIndexType(), gep); - cumulativeSize = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); - } - - /// Returns the type of a pointer to an element of the memref. - Type getElementPtrType(MemRefType memRefType) const { - auto elementType = memRefType.getElementType(); - auto structElementType = typeConverter.convertType(elementType); - return structElementType.template cast().getPointerTo( - memRefType.getMemorySpace()); - } - /// Returns the memref's element size in bytes. // TODO: there are other places where this is used. Expose publicly? static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { @@ -1856,7 +1855,7 @@ MemRefType memRefType, Value one, Value &accessAlignment, Value &allocatedBytePtr, ConversionPatternRewriter &rewriter) const { - auto elementPtrType = getElementPtrType(memRefType); + auto elementPtrType = this->getElementPtrType(memRefType); // With alloca, one gets a pointer to the element type right away. // For stack allocations. @@ -1959,9 +1958,10 @@ // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). SmallVector sizes; - Value cumulativeSize, one; - getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one); + this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes); + Value cumulativeSize = this->getCumulativeSizeInBytes( + loc, memRefType.getElementType(), sizes, rewriter); // Allocate the underlying buffer. 
// Value holding the alignment that has to be performed post allocation // (in conjunction with allocators that do not support alignment, eg. @@ -1970,8 +1970,9 @@ // Byte pointer to the allocated buffer. Value allocatedBytePtr; Value allocatedTypePtr = - allocateBuffer(loc, cumulativeSize, op, memRefType, one, - accessAlignment, allocatedBytePtr, rewriter); + allocateBuffer(loc, cumulativeSize, op, memRefType, + createIndexConstant(rewriter, loc, 1), accessAlignment, + allocatedBytePtr, rewriter); int64_t offset; SmallVector strides; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -36,6 +36,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> @@ -76,6 +77,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // CHECK-NEXT: llvm.mlir.undef 
: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> @@ -105,6 +107,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> @@ -150,6 +153,7 @@ // ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// ALIGNED-ALLOC-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*"> // ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -74,6 +74,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: 
llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> @@ -88,6 +89,7 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> @@ -126,9 +128,10 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> @@ -137,7 +140,7 @@ // CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : 
!llvm.i64 // CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 // CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> // CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 @@ -149,9 +152,10 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // BAREPTR-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 -// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> // BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> @@ -160,7 +164,7 @@ // BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 // BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 // BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : 
!llvm.i64 -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> // BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> // BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 @@ -182,6 +186,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> @@ -193,6 +198,7 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> // BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> %0 = alloc() : memref<32x18xf32> @@ -211,6 +217,7 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 // CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : 
!llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*"> %0 = alloca() : memref<32x18xf32>