diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -82,8 +82,8 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/BufferUtils.h"
@@ -128,16 +128,25 @@
   return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
 }
 
+OpResult getMatchingOpResult(VectorTransferOpInterface op,
+                             OpOperand &opOperand) {
+  if (opOperand.get() != op.source() ||
+      !op.source().getType().isa<TensorType>())
+    return OpResult();
+  return op->getResult(0);
+}
+
 /// Determine which results may be reused inplace by the bufferization
 /// patterns of `bufferizeFuncOpInternals`.
 /// The inplace analysis uses this information along with interfering read
 /// analysis to determine which op results reuse the same buffer as some
 /// operand.
 OpResult getMatchingOpResult(OpOperand &opOperand) {
-  OpResult res =
-      llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-          .Case([&](LinalgOp op) { return getMatchingOpResult(op, opOperand); })
-          .Default([&](Operation *op) { return OpResult(); });
+  OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+                     .Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
+                       return getMatchingOpResult(op, opOperand);
+                     })
+                     .Default([&](Operation *op) { return OpResult(); });
   return res;
 }
 
@@ -708,6 +717,54 @@
   return success();
 }
 
+static LogicalResult convertTransferOp(OpBuilder &b,
+                                       VectorTransferOpInterface op,
+                                       BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+
+  if (op.getShapedType().isa<MemRefType>())
+    return failure();
+
+  LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+  /// transfer_read from buffer
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
+    readOp.sourceMutable().assign(lookup(bvm, op.source()));
+    return success();
+  }
+
+  auto inPlace = getInPlace(op->getResult(0));
+  auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
+
+  // If transfer_write is not inPlace, allocate a new buffer.
+  Value newInputBuffer;
+  if (inPlace != InPlaceSpec::True) {
+    newInputBuffer =
+        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
+    map(bvm, writeOp.result(), newInputBuffer);
+  } else {
+    // InPlace write will result in memref.tensor_load(x) which must
+    // canonicalize away with one of its uses.
+    newInputBuffer = lookup(bvm, writeOp.source());
+  }
+
+  // Create a new transfer_write on buffer that doesn't have a return value.
+  // Leave the previous transfer_write to dead code, as it still has uses at
+  // this point.
+  b.create<vector::TransferWriteOp>(
+      loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
+      writeOp.permutation_map(),
+      writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+
+  map(bvm, op->getResult(0), newInputBuffer);
+
+  return success();
+}
+
 static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
                                            const DominanceInfo &domInfo) {
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -733,6 +790,9 @@
         .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
         .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
         .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+        .Case([&](VectorTransferOpInterface op) {
+          return convertTransferOp(b, op, bvm);
+        })
         .Default([&](Operation *op) {
           auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
           if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -81,3 +81,39 @@
     -> tensor
   return %r: tensor
 }
+// -----
+
+// CHECK-LABEL: func @vec_inplace
+func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> tensor<?xf32>
+{
+  %c0 = constant 0 : index
+  // CHECK-NOT: alloc
+  %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+  return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vec_not_inplace
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
+func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_1d_dyn]]>
+
+  /// Cross-op multiple uses of %A: the first vector.transfer, which has interfering reads, must alloc.
+  // CHECK: %[[ALLOC:.*]] = memref.alloc
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
+  %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  /// The second vector.transfer has no interfering reads and can reuse the buffer.
+  // CHECK-NOT: alloc
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[BUFFER_CAST]]
+  %r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
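
Editor's note, not part of the patch: a rough sketch of the bufferized IR the new convertTransferOp path is expected to produce for @vec_not_inplace, reconstructed from the CHECK lines above. The SSA names, the alloc sizing via memref.dim, and the #map layout are illustrative assumptions, not the pass's verbatim output.

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  func @vec_not_inplace(%A: tensor<?xf32> {linalg.inplaceable = true}, %vec: vector<4xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %0 = memref.buffer_cast %A : memref<?xf32, #map>
    // The first write cannot reuse %A's buffer because %A is used again below,
    // so it targets a fresh allocation.
    %d = memref.dim %0, %c0 : memref<?xf32, #map>
    %1 = memref.alloc(%d) : memref<?xf32>
    vector.transfer_write %vec, %1[%c0] : vector<4xf32>, memref<?xf32>
    // The second write has no interfering reads and writes into %A's buffer in place.
    vector.transfer_write %vec, %0[%c1] : vector<4xf32>, memref<?xf32, #map>
    %r0 = memref.tensor_load %1 : memref<?xf32>
    %r1 = memref.tensor_load %0 : memref<?xf32, #map>
    return %r0, %r1 : tensor<?xf32>, tensor<?xf32>
  }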