diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -196,13 +196,24 @@
   return inBoundsCondition;
 }
 
+// Helper that hoists the temporary buffer allocation to the entry block of
+// the enclosing function, as a 128-byte-aligned `alloca`, so it is not
+// re-allocated on every loop iteration at the original insertion point.
+// TODO: Parallelism and threadlocal considerations.
+static Value setAllocAtFunctionEntry(MemRefType memRefMinorVectorType,
+                                     Operation *op) {
+  auto &b = ScopedContext::getBuilderRef();
+  // Save and restore the current insertion point around the hoisted alloca.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(&op->getParentOfType<FuncOp>().front());
+  Value res =
+      std_alloca(memRefMinorVectorType, ValueRange{}, b.getI64IntegerAttr(128));
+  return res;
+}
+
 template <>
 LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
   Value alloc, result;
   if (options.unroll)
     result = std_splat(vectorType, xferOp.padding());
   else
-    alloc = std_alloc(memRefMinorVectorType);
+    alloc = setAllocAtFunctionEntry(memRefMinorVectorType, op);
 
   emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets,
                 ValueRange majorOffsets, ValueRange minorOffsets,
@@ -297,7 +308,7 @@
 LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
   Value alloc;
   if (!options.unroll) {
-    alloc = std_alloc(memRefMinorVectorType);
+    alloc = setAllocAtFunctionEntry(memRefMinorVectorType, op);
     std_store(xferOp.vector(),
               vector_type_cast(MemRefType::get({}, vectorType), alloc));
   }