diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -334,7 +334,14 @@
 /// that is specific to ops from a certain dialect can be stored in derived
 /// variants of this struct.
 struct DialectBufferizationState {
+  DialectBufferizationState() = default;
+
   virtual ~DialectBufferizationState() = default;
+
+  // Copying state is forbidden. Always pass as reference.
+  DialectBufferizationState(const DialectBufferizationState &) = delete;
+  DialectBufferizationState &
+  operator=(const DialectBufferizationState &) = delete;
 };
 
 /// BufferizationState keeps track of memory buffers and provides a variety of
@@ -373,10 +380,17 @@
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
 
+  /// Replace an op with replacement values, one per OpResult. The op is
+  /// deleted. Tensor OpResults must be replaced with memref values; their
+  /// existing uses are redirected to a `to_tensor` op wrapping the memref.
+  void replaceOp(Operation *op, ValueRange values);
+
   /// Map tensor values to memref buffers.
+  // TODO: Deprecated. Remove all uses of this method. Use `replaceOp` instead.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 
   /// Map a tensor value to a memref buffer.
+  // TODO: Deprecated. Remove all uses of this method. Use `replaceOp` instead.
   void mapBuffer(Value tensor, Value buffer);
 
   /// Lookup the memref buffer that is associated to the given tensor value.
@@ -387,6 +401,7 @@
   bool isInPlace(OpResult opResult) const;
 
   /// Return `true` if the given value is mapped.
+  // TODO: Deprecated. Remove all uses of this method.
   bool isMapped(Value value) const;
 
   /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
@@ -395,9 +410,11 @@
   Value getResultBuffer(OpResult result);
 
   /// Mark `op` as obsolete, so that it is deleted after bufferization.
+  // TODO: Deprecated. Remove all uses of this method.
   void markOpObsolete(Operation *op);
 
   /// Erase all ops that were marked obsolete.
+  // TODO: Deprecated. Remove all uses of this method.
   void eraseObsoleteOps();
 
   /// Return dialect-specific bufferization state.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -417,6 +417,39 @@
   return operandBuffer;
 }
 
+void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
+    Operation *op, ValueRange values) {
+  assert(values.size() == op->getNumResults() &&
+         "expected one replacement value per op result");
+  OpBuilder &b = getBuilder();
+  OpBuilder::InsertionGuard g(b);
+
+  // Replace all OpResults with the given values.
+  for (OpResult opResult : op->getOpResults()) {
+    // Skip OpResult if it has no uses.
+    if (opResult.getUses().empty())
+      continue;
+
+    Value replacement = values[opResult.getResultNumber()];
+    if (opResult.getType().isa<TensorType>()) {
+      // The OpResult is a tensor. Such values are replaced with memrefs during
+      // bufferization.
+      assert((replacement.getType().isa<MemRefType>() ||
+              replacement.getType().isa<UnrankedMemRefType>()) &&
+             "tensor op result should be replaced with a memref value");
+      // The existing uses of the OpResult still expect a tensor. Insert a
+      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
+      // lose all of its users and eventually DCE away.
+      setInsertionPointAfter(b, replacement);
+      replacement = b.create<bufferization::ToTensorOp>(replacement.getLoc(),
+                                                        replacement);
+    }
+    opResult.replaceAllUsesWith(replacement);
+  }
+
+  op->erase();
+}
+
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Region *region,
                                                  BufferizationState &state) {
@@ -429,8 +462,14 @@
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
                                                  BufferizationState &state) {
+  // Ops may get deleted during the traversal, so do not iterate over `block`
+  // directly.
+  SmallVector<Operation *> ops;
+  ops.reserve(block->getOperations().size());
   for (Operation &op : *block)
-    if (failed(bufferize(&op, state)))
+    ops.push_back(&op);
+  for (Operation *op : ops)
+    if (failed(bufferize(op, state)))
       return failure();
   return success();
 }
@@ -651,10 +690,13 @@
 /// Wrapper for better debugging.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
     Value tensor) {
-  // TODO: if key comes from bbArg, forward.
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  Value buffer = mapping.lookupOrNull(tensor);
 
+  // Replace "%t = to_tensor %m" with %m.
+  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+    return toTensorOp.memref();
+
+  Value buffer = mapping.lookupOrNull(tensor);
   if (!buffer) {
     if (options.allowUnknownOps) {
       // `tensor` was not bufferized yet. This should never happen with
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -25,14 +25,17 @@
 // TODO: These ops should implement BufferizableOpInterface directly when moved
 // to the Bufferization dialect.
 
-// TODO: These implementations are conservative and will likely have to be
-// loosened for partial bufferization.
-
 /// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
 /// location of the incoming tensor once it will be bufferized. In the anlysis,
 /// the incoming tensor is assumed to bufferize to a memory read and to an
 /// inplace memory write, since it is unknown what will happen to the resulting
 /// memref.
+///
+/// Note: ToMemrefOp / ToTensorOp are temporary ops that are inserted at the
+/// bufferization boundary. When bufferization is complete, there should be no
+/// such ops left over.
If `allowUnknownOps`, such ops may be part of the +/// resulting IR, but such IR may no longer be bufferizable by Comprehensive +/// Bufferize. struct ToMemrefOpInterface : public BufferizableOpInterface::ExternalModel { @@ -47,6 +50,35 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { + auto toMemrefOp = cast(op); + + // Fold to_memref(to_tensor(x)) to x. + if (auto toTensorOp = + toMemrefOp.tensor().getDefiningOp()) { + toMemrefOp.replaceAllUsesWith(toTensorOp.memref()); + toMemrefOp.erase(); + return success(); + } + + // If a ToMemrefOp's tensor operand has not been bufferized yet, the op + // remains unchanged. All IR up to this ToMemrefOp has already been + // bufferized, unless there were unknown ops that could not be bufferized. + if (!state.isMapped(toMemrefOp.tensor())) { + assert(state.getOptions().allowUnknownOps && + "expected that tensor is mapped"); + return success(); + } + + // If a ToMemrefOp's tensor operand has been bufferized, all uses of the op + // are redirected to that buffer. The op itself is left in place (dead). + Value memref = state.lookupBuffer(toMemrefOp.tensor()); + // Do not replace a ToMemrefOp with itself. E.g., when bufferizing a + // function body, ToMemrefOps were inserted before starting bufferization of + // the function body. Such ToMemrefOps are replaced in a separate step after + // the function body has been bufferized.
+ if (toMemrefOp.getResult() != memref) + toMemrefOp.replaceAllUsesWith(memref); + return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -7,10 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { namespace linalg { @@ -214,51 +216,79 @@ return true; } - LogicalResult bufferize(Operation *op, OpBuilder &b, + LogicalResult bufferize(Operation *op, OpBuilder & /*b*/, BufferizationState &state) const { auto forOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - for (OpResult opResult : forOp->getResults()) { - if (!opResult.getType().isa()) - continue; - // TODO: Atm we bail on unranked TensorType because we don't know how to - // alloc an UnrankedMemRefType + its underlying ranked MemRefType. - assert(opResult.getType().isa() && - "unsupported unranked tensor"); - - // TODO: More general: Matching bbArg does not bufferize to a read.
- Value resultBuffer = state.getResultBuffer(opResult); - if (!resultBuffer) - return failure(); - - OpOperand &opOperand = forOp.getOpOperandForResult(opResult); - BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); - state.mapBuffer(bbArg, resultBuffer); - state.mapBuffer(opResult, resultBuffer); - } + Block *oldLoopBody = &forOp.getLoopBody().front(); + + // Use IRRewriter instead of OpBuilder because it has additional helper + // functions (e.g., eraseOp, mergeBlocks). + IRRewriter rewriter(op->getContext()); + rewriter.setInsertionPoint(forOp); + + // Indices of all iter_args that have tensor type. These are the ones that + // are bufferized. + DenseSet indices; + for (const auto &it : llvm::enumerate(forOp.initArgs())) + if (it.value().getType().isa()) + indices.insert(it.index()); + // NOTE(review): the removed code asserted that tensor results are ranked + // ("unsupported unranked tensor"); no such check remains here. Confirm + // that getResultBuffer handles unranked tensors. + + // Given a range of values, apply `func` to those marked in `indices`. + // Otherwise, store the unmodified value in the result vector. + auto convert = [&](ValueRange values, + std::function func) { + SmallVector result; + for (const auto &it : llvm::enumerate(values)) { + size_t idx = it.index(); + Value val = it.value(); + result.push_back(indices.contains(idx) ? func(val, idx) : val); + } + return result; + }; + + // Construct a new scf.for op with memref instead of tensor values. + SmallVector initArgs = + convert(forOp.initArgs(), [&](Value val, int64_t index) { + return state.getResultBuffer(forOp->getOpResult(index)); + }); + // NOTE(review): the removed code returned failure() when getResultBuffer() + // produced a null Value. Here the result flows unchecked into the new + // scf.for. Consider bailing out with failure() if any entry of `initArgs` + // is null. + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.lowerBound(), + forOp.upperBound(), forOp.step(), initArgs); + Block *loopBody = &newForOp.getLoopBody().front(); + + // Set up new iter_args. The loop body uses tensors, so wrap the (memref) + // iter_args of the new loop in ToTensorOps.
+ rewriter.setInsertionPointToStart(loopBody); + SmallVector iterArgs = + convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) { + return rewriter.create(val.getLoc(), val); + }); + iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar()); + + // Erase terminator if present. `iterArgs` holding only the induction + // variable means the loop has no iter_args. Presumably the scf.for builder + // then creates an implicit scf.yield (TODO confirm). That yield must go + // before the old body, which brings its own terminator, is merged in. + if (iterArgs.size() == 1) + rewriter.eraseOp(loopBody->getTerminator()); + + // Move loop body to new loop. + rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs); + + // Update scf.yield of new loop. + auto yieldOp = cast(loopBody->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + SmallVector yieldValues = + convert(yieldOp.results(), [&](Value val, int64_t index) { + return rewriter.create( + val.getLoc(), initArgs[index].getType(), val); + }); + yieldOp.resultsMutable().assign(yieldValues); + + // Replace loop results. Uses of tensor results are redirected to to_tensor + // ops that wrap the memref results of the new loop. The old loop is erased. + state.replaceOp(op, newForOp->getResults()); // Bufferize loop body. - if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state))) + if (failed(comprehensive_bufferize::bufferize(loopBody, state))) return failure(); - // Finish bufferizing scf::ForOp. - auto yieldOp = cast(&forOp.region().front().back()); - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - - // Buffers are equivalent so the work is already done and we just yield - // the bbArg so that it later canonicalizes away.
- operand.set(bbArg); - } return success(); } }; diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -1,14 +1,12 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s -// TODO: Bufferize result IR of bufferization. -// TODO: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s - // Run fuzzer with different seeds. // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR +// RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=scf allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF // CHECK-LABEL: func @use_of_unknown_op_1( // CHECK-SAME: %[[m1:.*]]: memref } + +// ----- + +// CHECK-SCF-LABEL: func @simple_scf_for( +// CHECK-SCF-SAME: %[[t1:.*]]: tensor +func @simple_scf_for( + %t1:
tensor, %sz: index, %step: index, %f: f32) -> tensor { + %c0 = arith.constant 0 : index + + // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] + // CHECK-SCF: %[[alloc:.*]] = memref.alloc + // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]] + // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) { + %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor { + // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]] + // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[arg0_tensor]] + %1 = tensor.insert %f into %arg0[%iv] : tensor + + // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]] + // CHECK-SCF: scf.yield %[[insert_memref]] + scf.yield %1 : tensor + } + // CHECK-SCF: } + + // CHECK-SCF: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]] + // CHECK-SCF: return %[[scf_for_tensor]] + return %0 : tensor +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -291,6 +291,7 @@ -> (tensor, tensor) { // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc + // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]] // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]]) // The first scf.for remains but just turns into dead code.
@@ -304,7 +305,7 @@ } // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref - // CHECK: return %[[ALLOC_FOR_A]] : memref + // CHECK: return %[[CASTED]] : memref return %r0, %r1: tensor, tensor } @@ -346,6 +347,7 @@ -> (tensor, tensor) { // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc + // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]] // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]]) // CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1] @@ -369,7 +371,7 @@ } // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref - // CHECK: return %[[ALLOC_FOR_A]] : memref + // CHECK: return %[[CASTED]] : memref return %r0#0, %r0#1: tensor, tensor } diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -101,7 +101,6 @@ // TODO: Find a way to enable this step automatically when bufferizing // tensor dialect ops. options.addPostAnalysisStep(); - options.addPostAnalysisStep(); options.allowReturnMemref = allowReturnMemref; options.allowUnknownOps = allowUnknownOps; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6413,6 +6413,7 @@ includes = ["include"], deps = [ ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":SCFDialect", ":Support",