diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -11,6 +11,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfoVariant.h" #include "llvm/ADT/SetVector.h" @@ -411,6 +412,10 @@ /// Specifies whether OpOperands with a different type that are not the result /// of a CastOpInterface op should be followed. bool followSameTypeOrCastsOnly = false; + + /// Specifies whether already visited values should be visited again. + /// (Note: This can result in infinite looping.) + bool revisitAlreadyVisitedValues = false; }; /// AnalysisState provides a variety of helper functions for dealing with @@ -726,4 +731,177 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc" +//===----------------------------------------------------------------------===// +// Helpers for Unstructured Control Flow +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +/// A template that provides a default implementation of `getAliasingOpOperands` +/// for ops that support unstructured control flow within their regions. +template +struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel + : public BufferizableOpInterface::ExternalModel { + + /// Return a list of operands that are forwarded to the given block argument. + /// I.e., find all predecessors of the block argument's owner and gather the + /// operands that are equivalent to the block argument. 
+ static SmallVector getCallerOpOperands(BlockArgument bbArg) { + SmallVector result; + Block *block = bbArg.getOwner(); + for (Operation *caller : block->getUsers()) { + auto branchOp = dyn_cast(caller); + if (!branchOp) + continue; + auto it = llvm::find(caller->getSuccessors(), block); + assert(it != caller->getSuccessors().end() && "could find successor"); + int64_t successorIdx = std::distance(caller->getSuccessors().begin(), it); + SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx); + assert(operands.getProducedOperandCount() == 0 && + "produced operands not supported"); + int64_t operandIndex = + operands.getForwardedOperands().getBeginOperandIndex() + + bbArg.getArgNumber(); + result.push_back(&caller->getOpOperand(operandIndex)); + } + return result; + } + + AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const AnalysisState &state) const { + // This is a default implementation for block arguments only. If the op also + // has tensor results, the external model should define its own + // implementation and call this base function in case of a block argument. + AliasingOpOperandList result; + auto bbArg = dyn_cast(value); + if (!bbArg) + return result; + + // Gather aliasing OpOperands of all operations (callers) that link to + // this block. + for (OpOperand *opOperand : getCallerOpOperands(bbArg)) + result.addAlias( + {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false}); + + return result; + } + + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + SmallVector &invocationStack) const { + // This is a default implementation for block arguments only. If the op also + // has tensor results, the external model should define its own + // implementation and call this base function in case of a block argument. 
+ if (isa(value)) + return bufferization::detail::defaultGetBufferType(value, options, + invocationStack); + + // Compute the buffer type of the block argument by computing the bufferized + // operand types of all forwarded values. If these are all the same type, + // take that type. Otherwise, take only the memory space and fall back to a + // buffer type with a fully dynamic layout map. + BaseMemRefType bufferType; + auto tensorType = cast(value.getType()); + for (OpOperand *opOperand : + getCallerOpOperands(cast(value))) { + + // If the forwarded operand is already on the invocation stack, we ran + // into a loop and this operand cannot be used to compute the bufferized + // type. + if (llvm::find(invocationStack, opOperand->get()) != + invocationStack.end()) + continue; + + // Compute the bufferized type of the forwarded operand. + BaseMemRefType callerType; + if (auto memrefType = + dyn_cast(opOperand->get().getType())) { + // The operand was already bufferized. Take its type directly. + callerType = memrefType; + } else { + FailureOr maybeCallerType = + bufferization::getBufferType(opOperand->get(), options, + invocationStack); + if (failed(maybeCallerType)) + return failure(); + callerType = *maybeCallerType; + } + + if (!bufferType) { + // This is the first buffer type that we computed. + bufferType = callerType; + continue; + } + + if (bufferType == callerType) + continue; + + // If the computed buffer type does not match the computed buffer type of + // the earlier forwarded operands, fall back to a buffer type with a fully + // dynamic layout map. 
+#ifndef NDEBUG + assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(), + tensorType.getShape()}) && + "expected same shape"); +#endif // NDEBUG + + if (bufferType.getMemorySpace() != callerType.getMemorySpace()) + return op->emitOpError("incoming operands of block argument have " + "inconsistent memory spaces"); + + bufferType = getMemRefTypeWithFullyDynamicLayout( + tensorType, bufferType.getMemorySpace()); + } + + if (!bufferType) + return op->emitOpError("could not infer buffer type of block argument"); + + return bufferType; + } +}; + +/// A template that provides a default implementation of `getAliasingValues` +/// for ops that implement the `BranchOpInterface`. +template +struct BranchOpBufferizableOpInterfaceExternalModel + : public BufferizableOpInterface::ExternalModel { + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + AliasingValueList result; + auto branchOp = cast(op); + auto operandNumber = opOperand.getOperandNumber(); + + // Gather aliasing block arguments of blocks to which this op may branch to. + for (const auto &it : llvm::enumerate(op->getSuccessors())) { + Block *block = it.value(); + SuccessorOperands operands = branchOp.getSuccessorOperands(it.index()); + assert(operands.getProducedOperandCount() == 0 && + "produced operands not supported"); + // The first and last operands that are forwarded to this successor. + int64_t firstOperandIndex = + operands.getForwardedOperands().getBeginOperandIndex(); + int64_t lastOperandIndex = + firstOperandIndex + operands.getForwardedOperands().size(); + bool matchingDestination = operandNumber >= firstOperandIndex && + operandNumber < lastOperandIndex; + // A branch op may have multiple successors. Find the ones that correspond + // to this OpOperand. (There is usually only one.) + if (!matchingDestination) + continue; + // Compute the matching block argument of the destination block. 
+ BlockArgument bbArg = + block->getArgument(operandNumber - firstOperandIndex); + result.addAlias( + {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false}); + } + + return result; + } +}; + +} // namespace bufferization +} // namespace mlir + #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -410,6 +410,11 @@ the input IR and returns `failure` in that case. If this op is expected to survive bufferization, `success` should be returned (together with `allow-unknown-ops` enabled). + + Note: If this op supports unstructured control flow in its regions, + then this function should also bufferize all block signatures that + belong to this op. Branch ops (that branch to a block) are typically + bufferized together with the block signature. }], /*retType=*/"::mlir::LogicalResult", /*methodName=*/"bufferize", diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -78,6 +78,21 @@ const OpFilter *opFilter = nullptr, BufferizationStatistics *statistics = nullptr); +/// Bufferize the signature of `block` and its callers. All block argument types +/// are changed to memref types. All corresponding operands of all callers +/// (i.e., ops that have the given block as a successor) are wrapped in +/// bufferization.to_memref ops. +/// +/// It is expected that all callers implement the `BranchOpInterface`. +/// Otherwise, this function will fail. 
The `BranchOpInterface` is used to query +/// the range of operands that are forwarded to this block. +/// +/// It is expected that the parent op of this block implements the +/// `BufferizableOpInterface`. The buffer types of tensor block arguments are +/// computed with `BufferizableOpInterface::getBufferType`. +LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, + const BufferizationOptions &options); + BufferizationOptions getPartialBufferizationOptions(); } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace cf { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace cf +} // namespace mlir + +#endif // MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -31,6 +31,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -157,6 +158,7 @@ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); builtin::registerCastOpInterfaceExternalModels(registry); + cf::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); linalg::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -234,6 +234,7 @@ }); if (aliasingValues.getNumAliases() == 1 && + isa(aliasingValues.getAliases()[0].value) && !state.bufferizesToMemoryWrite(opOperand) && state.getAliasingOpOperands(aliasingValues.getAliases()[0].value) 
.getNumAliases() == 1 && @@ -498,11 +499,16 @@ bool AnalysisState::isValueRead(Value value) const { assert(llvm::isa(value.getType()) && "expected TensorType"); SmallVector workingSet; + DenseSet visited; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); + if (visited.contains(uMaybeReading)) + continue; + visited.insert(uMaybeReading); + // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) for (AliasingValue alias : getAliasingValues(*uMaybeReading)) @@ -522,11 +528,21 @@ llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, TraversalConfig config) const { + llvm::DenseSet visited; llvm::SetVector result, workingSet; workingSet.insert(value); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); + + if (!config.revisitAlreadyVisitedValues && visited.contains(value)) { + // Stop traversal if value was already visited. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } + visited.insert(value); + if (condition(value)) { result.insert(value); continue; @@ -659,11 +675,15 @@ // preceding value, so we can follow SSA use-def chains and do a simple // analysis. 
SmallVector worklist; + DenseSet visited; for (OpOperand &use : tensor.getUses()) worklist.push_back(&use); while (!worklist.empty()) { OpOperand *operand = worklist.pop_back_val(); + if (visited.contains(operand)) + continue; + visited.insert(operand); Operation *op = operand->getOwner(); // If the op is not bufferizable, we can safely assume that the value is not diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -321,6 +322,16 @@ /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { + bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) { + return any_of(r.getBlocks(), [](Block &b) { + return any_of(b.getArguments(), [](BlockArgument bbArg) { + return isaTensor(bbArg.getType()); + }); + }); + }); + if (hasTensorBlockArgument) + return true; + if (auto funcOp = dyn_cast(op)) { bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); @@ -535,6 +546,93 @@ return success(); } +LogicalResult +bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, + const BufferizationOptions &options) { + OpBuilder::InsertionGuard g(rewriter); + auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp()); + if (!bufferizableOp) + return failure(); + + // Compute the new signature. 
+ SmallVector newTypes; + for (BlockArgument &bbArg : block->getArguments()) { + auto tensorType = dyn_cast(bbArg.getType()); + if (!tensorType) { + newTypes.push_back(bbArg.getType()); + continue; + } + + FailureOr memrefType = + bufferization::getBufferType(bbArg, options); + if (failed(memrefType)) + return failure(); + newTypes.push_back(*memrefType); + } + + // Change the type of all block arguments. + for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) { + if (bbArg.getType() == type) + continue; + + // Collect all uses of the bbArg. + SmallVector bbArgUses; + for (OpOperand &use : bbArg.getUses()) + bbArgUses.push_back(&use); + + // Change the bbArg type to memref. + bbArg.setType(type); + + // Replace all uses of the original tensor bbArg. + rewriter.setInsertionPointToStart(block); + if (!bbArgUses.empty()) { + Value toTensorOp = + rewriter.create(bbArg.getLoc(), bbArg); + for (OpOperand *use : bbArgUses) + use->set(toTensorOp); + } + } + + // Bufferize callers of the block. + for (Operation *op : block->getUsers()) { + auto branchOp = dyn_cast(op); + if (!branchOp) + return op->emitOpError("cannot bufferize ops with block references that " + "do not implement BranchOpInterface"); + + auto it = llvm::find(op->getSuccessors(), block); + assert(it != op->getSuccessors().end() && "could find successor"); + int64_t successorIdx = std::distance(op->getSuccessors().begin(), it); + + SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx); + SmallVector newOperands; + for (auto [operand, type] : + llvm::zip(operands.getForwardedOperands(), newTypes)) { + if (operand.getType() == type) { + // Not a tensor type. Nothing to do for this operand. 
+ newOperands.push_back(operand); + continue; + } + FailureOr operandBufferType = + bufferization::getBufferType(operand, options); + if (failed(operandBufferType)) + return failure(); + rewriter.setInsertionPointAfterValue(operand); + Value bufferizedOperand = rewriter.create( + operand.getLoc(), *operandBufferType, operand); + // A cast is needed if the operand and the block argument have different + // bufferized types. + if (type != *operandBufferType) + bufferizedOperand = rewriter.create( + operand.getLoc(), type, bufferizedOperand); + newOperands.push_back(bufferizedOperand); + } + operands.getMutableForwardedOperands().assign(newOperands); + } + + return success(); +} + BufferizationOptions bufferization::getPartialBufferizationOptions() { BufferizationOptions options; options.allowUnknownOps = true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -318,16 +319,39 @@ }; struct FuncOpInterface - : public BufferizableOpInterface::ExternalModel { + : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + FuncOpInterface, FuncOp> { + + static bool supportsUnstructuredControlFlow() { return true; } + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { 
auto funcOp = cast(op); auto bbArg = cast(value); - // Unstructured control flow is not supported. - assert(bbArg.getOwner() == &funcOp.getBody().front() && - "expected that block argument belongs to first block"); - return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + + // Function arguments are special. + if (bbArg.getOwner() == &funcOp.getBody().front()) + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), + options); + + return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: + getBufferType(op, value, options, invocationStack); + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + auto funcOp = cast(op); + // TODO: func.func with multiple returns are not supported. + if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.getBody().empty()) + return op->emitOpError("op without unique func.return is not supported"); + const auto &options = + static_cast(state.getOptions()); + // allow-return-allocs is required for ops with multiple blocks. + if (options.allowReturnAllocs || funcOp.getRegion().getBlocks().size() <= 1) + return success(); + return op->emitOpError( + "op cannot be bufferized without allow-return-allocs"); } /// Rewrite function bbArgs and return values into buffer form. This function @@ -374,37 +398,11 @@ assert(returnOp && "expected func with single return op"); Location loc = returnOp.getLoc(); - // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. - Block &frontBlock = funcOp.getBody().front(); - for (BlockArgument &bbArg : frontBlock.getArguments()) { - auto tensorType = dyn_cast(bbArg.getType()); - // Non-tensor types stay the same. - if (!tensorType) - continue; - - // Collect all uses of the bbArg. - SmallVector bbArgUses; - for (OpOperand &use : bbArg.getUses()) - bbArgUses.push_back(&use); - - // Change the bbArg type to memref. 
- FailureOr memrefType = - bufferization::getBufferType(bbArg, options); - if (failed(memrefType)) + // 1. Bufferize every block. + for (Block &block : funcOp.getBody()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, + options))) return failure(); - bbArg.setType(*memrefType); - - // Replace all uses of the original tensor bbArg. - rewriter.setInsertionPointToStart(&frontBlock); - if (!bbArgUses.empty()) { - // Insert to_tensor because the remaining function body has not been - // bufferized yet. - Value toTensorOp = - rewriter.create(funcOp.getLoc(), bbArg); - for (OpOperand *use : bbArgUses) - use->set(toTensorOp); - } - } // 2. For each result, keep track of which inplace argument it reuses. SmallVector returnValues; @@ -445,6 +443,11 @@ BlockArgument bbArg = dyn_cast(value); assert(bbArg && "expected BlockArgument"); + // Non-entry block arguments are always writable. (They may alias with + // values that are not writable, which will turn them into read-only.) + if (bbArg.getOwner() != &funcOp.getBody().front()) + return true; + // "bufferization.writable" overrides other writability decisions. This is // currently used for testing only. if (BoolAttr writable = funcOp.getArgAttrOfType( diff --git a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt --- a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,75 @@ +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace cf { +namespace { + +template +struct BranchLikeOpInterface + : public BranchOpBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + const auto &options = + static_cast(state.getOptions()); + if (options.allowReturnAllocs) + return success(); + return op->emitOpError("op cannot be bufferized without allow-return-allocs"); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + // The operands of this op are bufferized together with the block signature. + return success(); + } +}; + +/// Bufferization of cf.br. +struct BranchOpInterface + : public BranchLikeOpInterface {}; + +/// Bufferization of cf.cond_br. 
+struct CondBranchOpInterface + : public BranchLikeOpInterface {}; + +} // namespace +} // namespace cf +} // namespace mlir + +void mlir::cf::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { + cf::BranchOp::attachInterface(*ctx); + cf::CondBranchOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRControlFlowTransforms + BufferizableOpInterfaceImpl.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms + + LINK_LIBS PUBLIC + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRControlFlowDialect + MLIRMemRefDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -99,37 +100,75 @@ } }; +/// Return the unique scf.yield op. If there are multiple or none scf.yield ops, +/// return an empty op. 
+static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { + scf::YieldOp result; + for (Block &block : executeRegionOp.getRegion()) { + if (auto yieldOp = dyn_cast(block.getTerminator())) { + if (result) + return {}; + result = yieldOp; + } + } + return result; +} + /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not /// fully implemented at the moment. struct ExecuteRegionOpInterface - : public BufferizableOpInterface::ExternalModel { + : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + ExecuteRegionOpInterface, scf::ExecuteRegionOp> { + + static bool supportsUnstructuredControlFlow() { return true; } + + bool isWritable(Operation *op, Value value, + const AnalysisState &state) const { + return true; + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + auto executeRegionOp = cast(op); + // TODO: scf.execute_region with multiple yields are not supported. + if (!getUniqueYieldOp(executeRegionOp)) + return op->emitOpError("op without unique scf.yield is not supported"); + const auto &options = + static_cast(state.getOptions()); + // allow-return-allocs is required for ops with multiple blocks. + if (options.allowReturnAllocs || + executeRegionOp.getRegion().getBlocks().size() == 1) + return success(); + return op->emitOpError( + "op cannot be bufferized without allow-return-allocs"); + } + AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const { + if (isa(value)) + return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: + getAliasingOpOperands(op, value, state); + // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value // is considered to be aliasing with the result. 
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); - size_t resultNum = std::distance(op->getOpResults().begin(), - llvm::find(op->getOpResults(), value)); - // TODO: Support multiple blocks. - assert(executeRegionOp.getRegion().getBlocks().size() == 1 && - "expected exactly 1 block"); - auto yieldOp = dyn_cast<scf::YieldOp>( - executeRegionOp.getRegion().front().getTerminator()); - assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); + auto it = llvm::find(op->getOpResults(), value); + assert(it != op->getOpResults().end() && "invalid value"); + size_t resultNum = std::distance(op->getOpResults().begin(), it); + auto yieldOp = getUniqueYieldOp(executeRegionOp); + // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail. + if (!yieldOp) + return {}; return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); - assert(executeRegionOp.getRegion().getBlocks().size() == 1 && - "only 1 block supported"); - auto yieldOp = - cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator()); + auto yieldOp = getUniqueYieldOp(executeRegionOp); TypeRange newResultTypes(yieldOp.getResults()); // Create new op and move over region. @@ -137,6 +176,12 @@ rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); newOp.getRegion().takeBody(executeRegionOp.getRegion()); + // Bufferize every block. + for (Block &block : newOp.getRegion()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, + options))) + return failure(); + // Update all uses of the old op.
rewriter.setInsertionPointAfter(newOp); SmallVector newResults; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -659,3 +659,15 @@ return %r1 : vector<5xf32> } + +// ----- + +// Note: The cf.br canonicalizes away, so there's nothing to test here. There is +// a detailed test in ControlFlow/bufferize.mlir. + +// CHECK-LABEL: func @br_in_func( +func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> { + cf.br ^bb1(%t : tensor<5xf32>) +^bb1(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +} diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs test-analysis-only dump-alias-sets bufferize-function-boundaries" -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @single_branch( +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[t:.*]]"]], [{{\[}}"%[[arg1]]", "%[[t]]"]]]]} +func.func @single_branch(%t: tensor<5xf32>) -> tensor<5xf32> { +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb1(%t : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>) +^bb1(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: func @diamond_branch( +// CHECK-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {{.*}}, %[[t1:.*]]: tensor<5xf32> {{.*}}) -> tensor<5xf32> +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[arg3:.*]]", "%[[arg2:.*]]", "%[[t0]]", "%[[t1]]"], [ +func.func @diamond_branch(%c: i1, %t0: 
tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> { +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %c, ^bb1(%t0 : tensor<5xf32>), ^bb2(%t1 : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>): +^bb3(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +// CHECK: ^{{.*}}(%[[arg2]]: tensor<5xf32>): +^bb1(%arg2 : tensor<5xf32>): +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb3(%arg2 : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg3]]: tensor<5xf32>): +^bb2(%arg3 : tensor<5xf32>): +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb3(%arg3 : tensor<5xf32>) +} + +// ----- + +// CHECK-LABEL: func @looping_branches( +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[], [{{\[}}"%[[arg2:.*]]", "%[[arg1:.*]]", "%[[inserted:.*]]", "%[[empty:.*]]"]], [ +func.func @looping_branches() -> tensor<5xf32> { +// CHECK: %[[empty]] = tensor.empty() + %0 = tensor.empty() : tensor<5xf32> +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb1(%0: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>): +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: %[[inserted]] = tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +^bb2(%arg2: tensor<5xf32>): + func.return %arg2 : tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: func @looping_branches_with_conflict( +func.func @looping_branches_with_conflict(%f: f32) -> tensor<5xf32> { + %0 = tensor.empty() : tensor<5xf32> + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> +// CHECK: 
cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["false"]} + cf.br ^bb1(%filled: tensor<5xf32>) +^bb2(%arg2: tensor<5xf32>): + %pos2 = "test.foo"() : () -> (index) + // One OpOperand cannot bufferize in-place because an "old" value is read. + %element = tensor.extract %filled[%pos2] : tensor<5xf32> + func.return %arg2 : tensor<5xf32> +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +} diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s -verify-diagnostics + +// expected-error @below{{failed to bufferize op}} +// expected-error @below{{incoming operands of block argument have inconsistent memory spaces}} +func.func @inconsistent_memory_space() -> tensor<5xf32> { + %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<5xf32> + cf.br ^bb1(%0: tensor<5xf32>) +^bb1(%arg1: tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +^bb2(): + %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<5xf32> + cf.br ^bb1(%1: tensor<5xf32>) +} + +// ----- + +// expected-error @below{{failed to bufferize op}} +// expected-error @below{{could not infer buffer type of block argument}} +func.func @cannot_infer_type() { + return + // The type of the block argument cannot be inferred. 
+^bb1(%t: tensor<5xf32>): + cf.br ^bb1(%t: tensor<5xf32>) +} diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s | FileCheck %s +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs" -split-input-file %s | FileCheck %s --check-prefix=CHECK-NO-FUNC + +// CHECK-NO-FUNC-LABEL: func @br( +// CHECK-NO-FUNC-SAME: %[[t:.*]]: tensor<5xf32>) +// CHECK-NO-FUNC: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>> +// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { +// CHECK-NO-FUNC: cf.br ^[[block:.*]](%[[m]] +// CHECK-NO-FUNC: ^[[block]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>): +// CHECK-NO-FUNC: scf.yield %[[arg1]] +// CHECK-NO-FUNC: } +// CHECK-NO-FUNC: return +func.func @br(%t: tensor<5xf32>) { + %0 = scf.execute_region -> tensor<5xf32> { + cf.br ^bb1(%t : tensor<5xf32>) + ^bb1(%arg1 : tensor<5xf32>): + scf.yield %arg1 : tensor<5xf32> + } + return +} + +// ----- + +// CHECK-NO-FUNC-LABEL: func @cond_br( +// CHECK-NO-FUNC-SAME: %[[t1:.*]]: tensor<5xf32>, +// CHECK-NO-FUNC: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32, strided<[?], offset: ?>> +// CHECK-NO-FUNC: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32> +// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { +// CHECK-NO-FUNC: cf.cond_br %{{.*}}, ^[[block1:.*]](%[[m1]] : {{.*}}), ^[[block2:.*]](%[[alloc]] : {{.*}}) +// CHECK-NO-FUNC: ^[[block1]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>): +// CHECK-NO-FUNC: scf.yield %[[arg1]] +// CHECK-NO-FUNC: ^[[block2]](%[[arg2:.*]]: memref<5xf32>): +// CHECK-NO-FUNC: %[[cast:.*]] = memref.cast %[[arg2]] : memref<5xf32> to 
memref<5xf32, strided<[?], offset: ?> +// CHECK-NO-FUNC: cf.br ^[[block1]](%[[cast]] : {{.*}}) +// CHECK-NO-FUNC: } +// CHECK-NO-FUNC: return +func.func @cond_br(%t1: tensor<5xf32>, %c: i1) { + // Use an alloc for the second block instead of a function block argument. + // A cast must be inserted because the two will have different layout maps. + %t0 = bufferization.alloc_tensor() : tensor<5xf32> + %0 = scf.execute_region -> tensor<5xf32> { + cf.cond_br %c, ^bb1(%t1 : tensor<5xf32>), ^bb2(%t0 : tensor<5xf32>) + ^bb1(%arg1 : tensor<5xf32>): + scf.yield %arg1 : tensor<5xf32> + ^bb2(%arg2 : tensor<5xf32>): + cf.br ^bb1(%arg2 : tensor<5xf32>) + } + return +} + +// ----- + +// CHECK-LABEL: func @looping_branches( +func.func @looping_branches() -> tensor<5xf32> { +// CHECK: %[[alloc:.*]] = memref.alloc + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: cf.br {{.*}}(%[[alloc]] + cf.br ^bb1(%0: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1:.*]]: memref<5xf32>): +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: memref.store %{{.*}}, %[[arg1]] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br {{.*}}(%[[arg1]] {{.*}}(%[[arg1]] + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg2:.*]]: memref<5xf32>): +^bb2(%arg2: tensor<5xf32>): +// CHECK: return %[[arg2]] + func.return %arg2 : tensor<5xf32> +} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir @@ -17,10 +17,10 @@ // ----- -func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> { - // expected-error @below{{op or BufferizableOpInterface implementation does not support unstructured control flow, but at least one 
region has multiple blocks}} +func.func @execute_region_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> { + // expected-error @below{{op op without unique scf.yield is not supported}} %0 = scf.execute_region -> tensor<5xf32> { - cf.br ^bb1(%t : tensor<5xf32>) + scf.yield %t : tensor<5xf32> ^bb1(%arg1 : tensor<5xf32>): scf.yield %arg1 : tensor<5xf32> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4005,6 +4005,24 @@ ], ) +cc_library( + name = "ControlFlowTransforms", + srcs = glob([ + "lib/Dialect/ControlFlow/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/mlir/Dialect/ControlFlow/Transforms/*.h", + ]), + includes = ["include"], + deps = [ + ":BufferizationDialect", + ":BufferizationTransforms", + ":ControlFlowDialect", + ":IR", + ":MemRefDialect", + ], +) + cc_library( name = "FuncDialect", srcs = glob( @@ -8164,6 +8182,7 @@ ":ComplexToLibm", ":ComplexToSPIRV", ":ControlFlowDialect", + ":ControlFlowTransforms", ":ConversionPasses", ":ConvertToLLVM", ":DLTIDialect",