diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -267,15 +267,24 @@
   Value maskBuffer;
 };
 
+/// Return the closest enclosing parent op of `op` that has the
+/// AutomaticAllocationScope trait. Asserts that such a parent exists.
+// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
+static Operation *getAutomaticAllocationScope(Operation *op) {
+  Operation *scope =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  return scope;
+}
+
 /// Allocate temporary buffers for data (vector) and mask (if present).
-/// TODO: Parallelism and threadlocal considerations.
 template <typename OpTy>
 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
   Location loc = xferOp.getLoc();
   OpBuilder::InsertionGuard guard(b);
-  Operation *scope =
-      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
-  assert(scope && "Expected op to be inside automatic allocation scope");
+  Operation *scope = getAutomaticAllocationScope(xferOp);
+  assert(scope->getNumRegions() == 1 &&
+         "AutomaticAllocationScope with >1 regions");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
 
   BufferAllocs result;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -438,6 +438,16 @@
   });
 }
 
+/// Return the closest enclosing parent op of `op` that has the
+/// AutomaticAllocationScope trait. Asserts that such a parent exists.
+// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
+static Operation *getAutomaticAllocationScope(Operation *op) {
+  Operation *scope =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  return scope;
+}
+
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
 ///
@@ -538,12 +548,14 @@
-  // Top of the function `alloc` for transient storage.
+  // Top of the enclosing allocation scope: `alloc` for transient storage.
   Value alloc;
   {
-    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
     RewriterBase::InsertionGuard guard(b);
-    b.setInsertionPointToStart(&funcOp.getRegion().front());
+    Operation *scope = getAutomaticAllocationScope(xferOp);
+    assert(scope->getNumRegions() == 1 &&
+           "AutomaticAllocationScope with >1 regions");
+    b.setInsertionPointToStart(&scope->getRegion(0).front());
     auto shape = xferOp.getVectorType().getShape();
     Type elementType = xferOp.getVectorType().getElementType();
-    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
+    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                        MemRefType::get(shape, elementType),
                                        ValueRange{}, b.getI64IntegerAttr(32));
   }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -481,3 +481,22 @@
 // CHECK-LABEL: transfer_write_strided(
 // CHECK: scf.for
 // CHECK: store
+
+// -----
+
+func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
+
+// CHECK-LABEL: transfer_read_within_async_execute
+func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.token {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  // CHECK-NOT: alloca
+  //     CHECK: async.execute
+  //     CHECK:   alloca
+  %token = async.execute {
+    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
+    call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
+    async.yield
+  }
+  return %token : !async.token
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -393,3 +393,22 @@
 // LINALG: }
 // LINALG: return
 // LINALG: }
+
+// -----
+
+func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
+
+// CHECK-LABEL: transfer_read_within_async_execute
+func @transfer_read_within_async_execute(%A : memref<?x?xf32>) -> !async.token {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  // CHECK-NOT: alloca
+  //     CHECK: async.execute
+  //     CHECK:   alloca
+  %token = async.execute {
+    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x?xf32>, vector<2x2xf32>
+    call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
+    async.yield
+  }
+  return %token : !async.token
+}