diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -378,6 +378,9 @@
 std::optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
                                                int memorySpace = -1);
 
+/// Returns the memref's element type's size in bytes if it can be computed.
+std::optional<uint64_t> getMemRefEltSizeInBytes(MemRefType memRefType);
+
 /// Simplify the integer set by simplifying the underlying affine expressions by
 /// flattening and some simple inference. Also, drop any duplicate constraints.
 /// Returns the simplified integer set. This method runs in time linear in the
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -184,7 +184,9 @@
 /// available for processing this block range. When 'filterMemRef' is specified,
 /// copies are only generated for the provided MemRef. Returns success if the
 /// explicit copying succeeded for all memrefs on which affine load/stores were
-/// encountered.
+/// encountered. For memrefs whose element type's size in bytes can't be
+/// computed (e.g. an `index` element type), their capacity is not accounted
+/// for, and the `fastMemCapacityBytes` copy option is ineffective for them.
 LogicalResult affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
                                      const AffineCopyOptions &copyOptions,
                                      std::optional<Value> filterMemRef,
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -594,16 +594,20 @@
   return success();
 }
 
-static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+std::optional<uint64_t> mlir::getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
   if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else if (auto vectorType = elementType.dyn_cast<VectorType>()) {
+    if (vectorType.getElementType().isIntOrFloat())
+      sizeInBits =
+          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    else
+      return std::nullopt;
   } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    return std::nullopt;
   }
   return llvm::divideCeil(sizeInBits, 8);
 }
@@ -629,7 +633,10 @@
     LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
     return std::nullopt;
   }
-  return getMemRefEltSizeInBytes(memRefType) * *numElements;
+  auto eltSize = getMemRefEltSizeInBytes(memRefType);
+  if (!eltSize)
+    return std::nullopt;
+  return *eltSize * *numElements;
 }
 
 /// Returns the size of memref data in bytes if it's statically shaped,
@@ -643,9 +650,11 @@
   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
     return std::nullopt;
 
-  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+  auto sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+  if (!sizeInBytes)
+    return std::nullopt;
   for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
-    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
+    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
   }
   return sizeInBytes;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -901,21 +901,6 @@
   node->op = newRootForOp;
 }
 
-// TODO: improve/complete this when we have target data.
-static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-  auto elementType = memRefType.getElementType();
-
-  unsigned sizeInBits;
-  if (elementType.isIntOrFloat()) {
-    sizeInBits = elementType.getIntOrFloatBitWidth();
-  } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
-  }
-  return llvm::divideCeil(sizeInBits, 8);
-}
-
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -976,7 +961,9 @@
 
   // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
   // by 'srcStoreOpInst'.
-  uint64_t bufSize = getMemRefEltSizeInBytes(oldMemRefType) * *numElements;
+  auto eltSize = getMemRefEltSizeInBytes(oldMemRefType);
+  assert(eltSize && "memrefs with sized element types expected");
+  uint64_t bufSize = *eltSize * *numElements;
   unsigned newMemSpace;
   if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
     newMemSpace = *fastMemorySpace;
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2181,7 +2181,10 @@
     // Record it.
     fastBufferMap[memref] = fastMemRef;
    // fastMemRefType is a constant shaped memref.
-    *sizeInBytes = *getMemRefSizeInBytes(fastMemRefType);
+    auto maySizeInBytes = getMemRefSizeInBytes(fastMemRefType);
+    // We don't account for buffers of unknown element size (e.g. `index`).
+    *sizeInBytes = maySizeInBytes.value_or(0);
+
     LLVM_DEBUG(emitRemarkForBlock(*block)
                << "Creating fast buffer of type " << fastMemRefType
                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir
--- a/mlir/test/Dialect/Affine/affine-data-copy.mlir
+++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir
@@ -310,3 +310,26 @@
   // CHECK-NEXT: affine.parallel
   return
 }
+
+// CHECK-LABEL: func @index_elt_type
+func.func @index_elt_type(%arg0: memref<1x2x4x8xindex>) {
+  affine.for %arg1 = 0 to 1 {
+    affine.for %arg2 = 0 to 2 {
+      affine.for %arg3 = 0 to 4 {
+        affine.for %arg4 = 0 to 8 {
+          affine.store %arg4, %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x2x4x8xindex>
+        }
+      }
+    }
+  }
+
+  // CHECK: affine.for %{{.*}} = 0 to 1
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 4
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 8
+
+  // CHECK: affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 4
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 8
+  return
+}