diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -389,6 +389,29 @@
         return mlir::bufferization::detail::defaultIsRepetitiveRegion(
             cast<BufferizableOpInterface>($_op.getOperation()), index);
       }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return `true` if allocations are allowed inside the given region of
+        this op. By default, allocations are allowed.
+
+        This method is queried during TensorCopyInsertion. If an allocation
+        would otherwise be inserted in a region that does not allow
+        allocations, it is inserted in the parent region instead.
+
+        Note: This method should be overridden only if setting the insertion
+        point to the parent region is generally safe. In particular, changing
+        the insertion point is not safe if the dynamic extents of an
+        allocation depend on an SSA value defined in the region that disallows
+        allocations.
+      }],
+      /*retType=*/"bool",
+      /*methodName=*/"areAllocationsAllowedInRegion",
+      /*args=*/(ins "unsigned":$index),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return true;
+      }]
+    >
   ];
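For illustration, this is how an op can opt out of in-region allocations via the new interface method. The sketch below assumes a placeholder external model MyOpInterface for a hypothetical MyOp; neither name is part of this patch. The vector.mask model added later in this patch does the same thing for real:

  struct MyOpInterface
      : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
    // Forbid allocations in all regions of MyOp. TensorCopyInsertion then
    // places any required bufferization.alloc_tensor in the parent region,
    // right before MyOp.
    bool areAllocationsAllowedInRegion(Operation *op, unsigned index) const {
      return false;
    }
  };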
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -96,12 +96,80 @@
   return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
 }
 
+/// Compute and set an allocation point for a tensor copy of the given shaped
+/// value. This function queries the BufferizableOpInterface to detect regions
+/// in which allocations are forbidden. In such a case, the allocation is
+/// placed in a parent region. Example:
+///
+/// vector.mask ... {
+///   vector.transfer_write %v, %t[%c0] : vector<5xf32>, tensor<?xf32>
+/// } : ... -> tensor<?xf32>
+///
+/// In case %t bufferizes out-of-place, the allocation must be placed outside
+/// of vector.mask, as per the op's BufferizableOpInterface implementation.
+///
+/// Note: Allocations may not move across repetitive region boundaries. If an
+/// allocation would be placed in a different repetitive region, return
+/// failure. This indicates an incorrect BufferizableOpInterface
+/// implementation.
+///
+/// Note: If the new insertion point violates op dominance, return failure.
+/// This also indicates an incorrect BufferizableOpInterface implementation.
+static LogicalResult
+setAllocationInsertionPoint(OpBuilder &b, Value shapedValue,
+                            const BufferizationOptions &options) {
+  Region *r = b.getInsertionBlock()->getParent();
+  Region *repetitiveRegion =
+      getEnclosingRepetitiveRegion(b.getInsertionBlock(), options);
+  Operation *ip = nullptr;
+  do {
+    Operation *op = r->getParentOp();
+    if (!op)
+      return success();
+    auto bufferizableOp = options.dynCastBufferizableOp(op);
+    if (bufferizableOp &&
+        !bufferizableOp.areAllocationsAllowedInRegion(r->getRegionNumber()))
+      ip = op;
+  } while ((r = r->getParentRegion()));
+
+  if (ip) {
+    // A custom insertion point is necessary.
+    if (getEnclosingRepetitiveRegion(ip, options) != repetitiveRegion)
+      // It is incorrect to set the buffer allocation point into a different
+      // repetitive region. This would effectively de-privatize a buffer.
+      return getOwnerOfValue(shapedValue)
+          ->emitError(
+              "unable to move tensor copy ip to different repetitive region");
+
+    // Check for op dominance errors.
+    if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+      if (!bbArg.getParentBlock()->findAncestorOpInBlock(*ip))
+        // The computed insertion point violates op dominance.
+        return getOwnerOfValue(shapedValue)
+            ->emitError(
+                "unable to find suitable insertion point for tensor copy");
+    } else {
+      Operation *shapedOp = shapedValue.dyn_cast<OpResult>().getDefiningOp();
+      Operation *ipInBlock = shapedOp->getBlock()->findAncestorOpInBlock(*ip);
+      if (!ipInBlock || shapedOp == ipInBlock ||
+          ipInBlock->isBeforeInBlock(shapedOp))
+        // The computed insertion point violates op dominance.
+        return getOwnerOfValue(shapedValue)
+            ->emitError(
+                "unable to find suitable insertion point for tensor copy");
+    }
+    b.setInsertionPoint(ip);
+  }
+  return success();
+}
+
 /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
 /// shaped value is copied. Otherwise, a tensor with undefined contents is
 /// allocated.
 FailureOr<Value> bufferization::allocateTensorForShapedValue(
     OpBuilder &b, Location loc, Value shapedValue, bool escape,
     const BufferizationOptions &options, bool copy) {
+  OpBuilder::InsertionGuard g(b);
   Value tensor;
   if (shapedValue.getType().isa<RankedTensorType>()) {
     tensor = shapedValue;
@@ -137,6 +205,10 @@
     populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
   }
 
+  // Compute the insertion point for the allocation.
+  if (failed(setAllocationInsertionPoint(b, shapedValue, options)))
+    return failure();
+
   // Create AllocTensorOp.
   auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                                copy ? tensor : Value());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -495,6 +495,15 @@
                                                   cast<ToMemrefOp>(op));
   }
 
+  // Remove all dead to_tensor ops.
+  op->walk([&](ToTensorOp toTensorOp) {
+    if (toTensorOp->getUses().empty()) {
+      rewriter.eraseOp(toTensorOp);
+      return WalkResult::skip();
+    }
+    return WalkResult::advance();
+  });
+
   /// Check the result of bufferization. Return an error if an op was not
   /// bufferized, unless partial bufferization is allowed.
   if (options.allowUnknownOps)
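To make the dominance check concrete: if the shaped value is defined inside the very region that disallows allocations, there is no legal insertion point in the parent region. A hypothetical IR sketch ("some.wrapper", "some.producer", and "some.consumer" are placeholder ops, not part of this patch):

  // Hypothetical: assume "some.wrapper" disallows allocations in its region.
  "some.wrapper"() ({
    // %t is defined inside the region; an out-of-place copy of %t cannot be
    // allocated before "some.wrapper" without violating dominance, so
    // setAllocationInsertionPoint emits
    // "unable to find suitable insertion point for tensor copy".
    %t = "some.producer"() : () -> tensor<?xf32>
    "some.consumer"(%t) : (tensor<?xf32>) -> ()
  }) : () -> ()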
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -131,6 +131,142 @@
   }
 };
 
+/// Bufferization of vector.mask. Replaced with a new vector.mask that
+/// operates on a memref.
+struct MaskOpInterface
+    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
+                                                    vector::MaskOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    // MaskOps do not have tensor OpOperands. The yielded values are the
+    // results of the wrapped op.
+    auto maskOp = cast<vector::MaskOp>(op);
+    size_t resultNum = std::distance(op->getOpResults().begin(),
+                                     llvm::find(op->getOpResults(), opResult));
+    auto yieldOp = dyn_cast<vector::YieldOp>(
+        maskOp.getMaskRegion().front().getTerminator());
+    assert(yieldOp && "expected vector.yield terminator in vector.mask");
+    return {&yieldOp->getOpOperand(resultNum)};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto maskOp = cast<vector::MaskOp>(op);
+
+    // Do not bufferize if the masked op is not bufferizable.
+    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
+    if (!options.dynCastBufferizableOp(maskedOp))
+      return success();
+
+    // Update the terminator: Drop all operands that are not results of the
+    // masked op.
+    auto yieldOp =
+        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
+    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
+    SmallVector<Value> newYieldedValues;
+    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+      if (llvm::find(maskedOp->getOpResults(), it.value()) !=
+          maskedOp->getOpResults().end()) {
+        newYieldedValues.push_back(it.value());
+      } else {
+        // This used to be a tensor result of the masked op, but is now a
+        // memref that is defined outside of the vector.mask op.
+        newReturnValues[it.index()] = it.value();
+      }
+    }
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getOperandsMutable().assign(newYieldedValues);
+    });
+
+    // Create a new vector.mask op.
+    TypeRange newResultTypes(newYieldedValues);
+    auto newOp = rewriter.create<vector::MaskOp>(
+        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru());
+    rewriter.eraseOp(newOp.getMaskRegion().front().getTerminator());
+    newOp.getMaskRegion().takeBody(maskOp.getMaskRegion());
+
+    // Replace all uses of the old vector.mask op.
+    int idx = 0;
+    for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
+      if (!newReturnValues[i])
+        newReturnValues[i] = newOp->getResult(idx++);
+    }
+    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
+    return success();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  bool areAllocationsAllowedInRegion(Operation *op, unsigned index) const {
+    return false;
+  }
+};
+
+/// Bufferization of vector.yield. Replaced with a new vector.yield that
+/// operates on a memref.
+struct YieldOpInterface
+    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
+                                                    vector::YieldOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const AnalysisState &state) const {
+    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield
+    // allocations when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto yieldOp = cast<vector::YieldOp>(op);
+
+    // Only supported as a vector.mask terminator.
+    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
+    if (!maskOp)
+      return yieldOp->emitError("unsupported vector::YieldOp parent");
+
+    // Do not bufferize if the masked op is not bufferizable.
+    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
+    if (!options.dynCastBufferizableOp(maskedOp))
+      return success();
+
+    // Create a new terminator with the same number of operands. Some of
+    // these may get dropped during the bufferization of vector.mask.
+    SmallVector<Value> newResults;
+    for (Value value : yieldOp.getOperands()) {
+      if (value.getType().isa<TensorType>()) {
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        newResults.push_back(*maybeBuffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace vector
 } // namespace mlir
@@ -141,5 +277,7 @@
     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
     GatherOp::attachInterface<GatherOpInterface>(*ctx);
+    MaskOp::attachInterface<MaskOpInterface>(*ctx);
+    YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });
 }
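As a usage sketch, the new models are attached through the existing registration hook. registerBufferizableOpInterfaceExternalModels is the real entry point (declared in mlir/include/mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h); the surrounding tool setup below is illustrative:

  DialectRegistry registry;
  registry.insert<vector::VectorDialect>();
  // Attaches MaskOpInterface, YieldOpInterface, etc. to the vector ops.
  vector::registerBufferizableOpInterfaceExternalModels(registry);
  MLIRContext context(registry);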
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -43,3 +43,21 @@
   %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @mask(
+//  CHECK-SAME:     %[[t0:.*]]: tensor<?xf32>, %[[val:.*]]: vector<16xf32>
+//  CHECK-SAME:     %[[idx:.*]]: index, %[[mask:.*]]: vector<16xi1>)
+//   CHECK-DAG:   %[[m:.*]] = bufferization.to_memref %[[t0]]
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]])
+//       CHECK:   memref.copy %[[m]], %[[alloc]]
+//       CHECK:   vector.mask %[[mask]] { vector.transfer_write %[[val]], %[[alloc]][%[[idx]]] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
+//       CHECK:   %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?xf32>
+//       CHECK:   return %[[r]]
+func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?xf32> {
+  %0 = vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Vector/one-shot-bufferize.mlir b/mlir/test/Dialect/Vector/one-shot-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/one-shot-bufferize.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @mask(
+//  CHECK-SAME:     %[[t0:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  // CHECK-NOT: alloc
+  // CHECK-NOT: copy
+  // CHECK: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[t0]][%{{.*}}] : vector<16xf32>, memref<?xf32, strided<[?], offset: ?>> } : vector<16xi1>
+  %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  // CHECK: return %[[t0]]
+  return %0 : tensor<?xf32>
+}
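A hypothetical variant (not part of this patch's tests) that would exercise the out-of-place path under One-Shot Bufferize: returning %t0 alongside the masked write result creates a read-after-write conflict, so the write must bufferize out-of-place. Because MaskOpInterface disallows allocations in the mask region, the alloc + copy are expected to appear before the vector.mask op:

  func.func @mask_out_of_place(%t0: tensor<?xf32>, %val: vector<16xf32>,
                               %idx: index, %m0: vector<16xi1>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
    return %t0, %0 : tensor<?xf32>, tensor<?xf32>
  }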