diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InplaceInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -13,6 +13,7 @@
 #ifndef VECTOR_OPS
 #define VECTOR_OPS
 
+include "mlir/Interfaces/InplaceInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1388,6 +1389,7 @@
 
 def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
+      InplaceOpInterface,
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -1486,6 +1488,15 @@
                    "AffineMap":$permutationMap,
                    "ArrayAttr":$inBounds)>,
   ];
+  let extraClassDeclaration = [{
+    // Return the OpResult that is tied to an operand.
+    OpResult getTiedOpResult(OpOperand &opOperand) {
+      if (opOperand.get() != source() || !source().getType().isa<RankedTensorType>())
+        return OpResult();
+      return getOperation()->getResult(0);
+    }
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -82,6 +82,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/InplaceInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -669,6 +670,54 @@
   return success();
 }
 
+/// Bufferize a vector.transfer_read / vector.transfer_write that operates on
+/// tensors: reads are redirected to the buffer backing the source tensor;
+/// writes either reuse the in-place buffer or allocate a new one.
+static LogicalResult convertTransferOp(OpBuilder &b,
+                                       VectorTransferOpInterface op,
+                                       BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+
+  // Already operating on memrefs: nothing to bufferize.
+  if (op.getShapedType().isa<MemRefType>())
+    return failure();
+
+  LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+  /// transfer_read from buffer: just rewire the source to the mapped buffer.
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
+    readOp.sourceMutable().assign(lookup(bvm, op.source()));
+    return success();
+  }
+
+  auto inPlace = getInPlace(op->getResult(0));
+  auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
+
+  // If transfer_write is not inPlace, allocate a new buffer.
+  Value newInputBuffer;
+  if (inPlace != InPlaceSpec::True) {
+    newInputBuffer =
+        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
+    map(bvm, writeOp.result(), newInputBuffer);
+  } else {
+    // InPlace write will result in memref.tensor_load(x) which must
+    // canonicalize away with one of its uses.
+    newInputBuffer = lookup(bvm, writeOp.source());
+  }
+
+  // Create a new transfer_write on buffer that doesn't have a return value.
+  // Leave the previous transfer_write to dead code as it still has uses at
+  // this point.
+  b.create<vector::TransferWriteOp>(
+      loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
+      writeOp.permutation_map(),
+      writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+
+  map(bvm, op->getResult(0), newInputBuffer);
+
+  return success();
+}
+
 static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
                                            const DominanceInfo &domInfo) {
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -694,6 +743,9 @@
         .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
         .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
         .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+        .Case([&](VectorTransferOpInterface op) {
+          return convertTransferOp(b, op, bvm);
+        })
         .Default([&](Operation *op) {
           auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
           if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -81,3 +81,39 @@
       -> tensor<?xf32>
   return %r: tensor<?xf32>
 }
+// -----
+
+// CHECK-LABEL: func @vec_inplace
+func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> tensor<?xf32>
+{
+  %c0 = constant 0 : index
+
+  // CHECK-NOT: alloc
+  %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+  return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vec_not_inplace
+//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
+func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32>
+
+  /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
+  //      CHECK: %[[ALLOC:.*]] = memref.alloc
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
+  %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  /// The second vector.transfer has no interfering reads and can reuse the buffer.
+  //  CHECK-NOT: alloc
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[BUFFER_CAST]]
+  %r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}