diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -496,6 +496,12 @@
                                                   cast<ToMemrefOp>(op));
   }
 
+  // Remove all dead to_tensor ops.
+  op->walk([&](ToTensorOp toTensorOp) {
+    if (toTensorOp->use_empty())
+      rewriter.eraseOp(toTensorOp);
+  });
+
   /// Check the result of bufferization. Return an error if an op was not
   /// bufferized, unless partial bufferization is allowed.
   if (options.allowUnknownOps)
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Dialect.h"
@@ -131,6 +132,157 @@
   }
 };
 
+/// Bufferization of vector.mask. Replaced with a new vector.mask that
+/// operates on a memref.
+struct MaskOpInterface
+    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
+                                                    vector::MaskOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    // MaskOps do not have tensor OpOperands. The yielded values are the result
+    // of the wrapped op: result #i of the vector.mask is yield operand #i.
+    auto maskOp = cast<vector::MaskOp>(op);
+    auto yieldOp =
+        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
+    return {&yieldOp->getOpOperand(opResult.getResultNumber())};
+  }
+
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+      return failure();
+
+    // TODO: Remove this function when vector.mask bodies can bufferize
+    // out-of-place. This is currently not supported because yielding allocs
+    // from a block leads to a memory leak and because vector.mask supports only
+    // a single op in its body.
+    auto maskOp = cast<vector::MaskOp>(op);
+    if (!maskOp.getMaskRegion()
+             .front()
+             .getOps<bufferization::AllocTensorOp>()
+             .empty())
+      return op->emitOpError("body must bufferize in-place");
+
+    return success();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto maskOp = cast<vector::MaskOp>(op);
+
+    // Do not bufferize if there is no masked op or it is not bufferizable.
+    Operation *maskedOp = maskOp.getMaskableOp();
+    if (!maskedOp || !options.dynCastBufferizableOp(maskedOp))
+      return success();
+
+    // Update the terminator: Drop all operands that are not results of the
+    // masked op.
+    auto yieldOp =
+        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
+    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
+    SmallVector<Value> newYieldedValues;
+    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+      if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
+        newYieldedValues.push_back(it.value());
+      } else {
+        // This used to be a tensor result of the masked op, but is now a memref
+        // that is defined outside of the vector.mask op.
+        newReturnValues[it.index()] = it.value();
+      }
+    }
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getOperandsMutable().assign(newYieldedValues);
+    });
+
+    // Create a new vector.mask op.
+    TypeRange newResultTypes(newYieldedValues);
+    auto newOp = rewriter.create<vector::MaskOp>(
+        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
+        /*maskableOp=*/nullptr,
+        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
+    newOp.getRegion().takeBody(maskOp.getMaskRegion());
+
+    // Replace all uses of the old vector.mask op. Results that were not
+    // replaced by a memref above are taken, in order, from the new op.
+    unsigned idx = 0;
+    for (unsigned i = 0, e = maskOp->getNumResults(); i < e; ++i) {
+      if (!newReturnValues[i])
+        newReturnValues[i] = newOp->getResult(idx++);
+    }
+    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
+    return success();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+};
+
+/// Bufferization of vector.yield. Replaced with a new vector.yield that
+/// operates on a memref.
+struct YieldOpInterface
+    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
+                                                    vector::YieldOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const AnalysisState &state) const {
+    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield allocations
+    // when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto yieldOp = cast<vector::YieldOp>(op);
+
+    // Only supported as a vector.mask terminator.
+    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
+    if (!maskOp)
+      return yieldOp->emitError("unsupported vector::YieldOp parent");
+
+    // Do not bufferize if there is no masked op or it is not bufferizable.
+    // Must agree with MaskOpInterface::bufferize, which uses the same lookup.
+    Operation *maskedOp = maskOp.getMaskableOp();
+    if (!maskedOp || !options.dynCastBufferizableOp(maskedOp))
+      return success();
+
+    // Create a new terminator with the same number of operands. Some of these
+    // may get dropped during the bufferization of vector.mask.
+    SmallVector<Value> newResults;
+    for (Value value : yieldOp.getOperands()) {
+      if (value.getType().isa<TensorType>()) {
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        newResults.push_back(*maybeBuffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace vector
 } // namespace mlir
@@ -141,5 +293,7 @@
     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
     GatherOp::attachInterface<GatherOpInterface>(*ctx);
+    MaskOp::attachInterface<MaskOpInterface>(*ctx);
+    YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Vector/bufferize-invalid.mlir b/mlir/test/Dialect/Vector/bufferize-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/bufferize-invalid.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s -vector-bufferize -split-input-file -verify-diagnostics
+
+func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  // expected-error @+1 {{'vector.mask' op body must bufferize in-place}}
+  %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -43,3 +43,6 @@
   %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
+
+// TODO: Add test case for vector.mask. The masked op currently cannot
+// bufferize out-of-place, so the only test case is in one-shot-bufferize.mlir.
diff --git a/mlir/test/Dialect/Vector/one-shot-bufferize.mlir b/mlir/test/Dialect/Vector/one-shot-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/one-shot-bufferize.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @mask(
+//  CHECK-SAME:     %[[t0:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  // CHECK-NOT: alloc
+  // CHECK-NOT: copy
+  //     CHECK: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[t0]][%{{.*}}] : vector<16xf32>, memref<?xf32, strided<[?], offset: ?>> } : vector<16xi1>
+  %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  //     CHECK: return %[[t0]]
+  return %0 : tensor<?xf32>
+}