diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -411,6 +411,10 @@ /// Specifies whether OpOperands with a different type that are not the result /// of a CastOpInterface op should be followed. bool followSameTypeOrCastsOnly = false; + + /// Specifies whether already visited values should be visited again. + /// (Note: This can result in infinite looping.) + bool revisitAlreadyVisitedValues = false; }; /// AnalysisState provides a variety of helper functions for dealing with diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -415,6 +415,13 @@ the input IR and returns `failure` in that case. If this op is expected to survive bufferization, `success` should be returned (together with `allow-unknown-ops` enabled). + + Note: If this op supports unstructured control flow in its regions, + then this function should also bufferize all block signatures that + belong to this op. Branch ops (that branch to a block) are typically + bufferized together with the block signature (this is just a + suggestion to make sure IR is valid at every point in time and could + be done differently). 
}], /*retType=*/"::mlir::LogicalResult", /*methodName=*/"bufferize", diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h @@ -0,0 +1,179 @@ +//===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" + +//===----------------------------------------------------------------------===// +// Helpers for Unstructured Control Flow +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +namespace detail { +/// Return a list of operands that are forwarded to the given block argument. +/// I.e., find all predecessors of the block argument's owner and gather the +/// operands that are equivalent to the block argument. +SmallVector getCallerOpOperands(BlockArgument bbArg); +} // namespace detail + +/// A template that provides a default implementation of `getAliasingOpOperands` +/// for ops that support unstructured control flow within their regions. 
+template +struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel + : public BufferizableOpInterface::ExternalModel { + + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + SmallVector &invocationStack) const { + // Note: The user may want to override this function for OpResults in + // case the bufferized result type is different from the bufferized type of + // the aliasing OpOperand (if any). + if (isa(value)) + return bufferization::detail::defaultGetBufferType(value, options, + invocationStack); + + // Compute the buffer type of the block argument by computing the bufferized + // operand types of all forwarded values. If these are all the same type, + // take that type. Otherwise, take only the memory space and fall back to a + // buffer type with a fully dynamic layout map. + BaseMemRefType bufferType; + auto tensorType = cast(value.getType()); + for (OpOperand *opOperand : + detail::getCallerOpOperands(cast(value))) { + + // If the forwarded operand is already on the invocation stack, we ran + // into a loop and this operand cannot be used to compute the bufferized + // type. + if (llvm::find(invocationStack, opOperand->get()) != + invocationStack.end()) + continue; + + // Compute the bufferized type of the forwarded operand. + BaseMemRefType callerType; + if (auto memrefType = + dyn_cast(opOperand->get().getType())) { + // The operand was already bufferized. Take its type directly. + callerType = memrefType; + } else { + FailureOr maybeCallerType = + bufferization::getBufferType(opOperand->get(), options, + invocationStack); + if (failed(maybeCallerType)) + return failure(); + callerType = *maybeCallerType; + } + + if (!bufferType) { + // This is the first buffer type that we computed. 
+ bufferType = callerType; + continue; + } + + if (bufferType == callerType) + continue; + + // If the computed buffer type does not match the computed buffer type + // of the earlier forwarded operands, fall back to a buffer type with a + // fully dynamic layout map. +#ifndef NDEBUG + if (auto rankedTensorType = dyn_cast(tensorType)) { + assert(bufferType.hasRank() && callerType.hasRank() && + "expected ranked memrefs"); + assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(), + rankedTensorType.getShape()}) && + "expected same shape"); + } else { + assert(!bufferType.hasRank() && !callerType.hasRank() && + "expected unranked memrefs"); + } +#endif // NDEBUG + + if (bufferType.getMemorySpace() != callerType.getMemorySpace()) + return op->emitOpError("incoming operands of block argument have " + "inconsistent memory spaces"); + + bufferType = getMemRefTypeWithFullyDynamicLayout( + tensorType, bufferType.getMemorySpace()); + } + + if (!bufferType) + return op->emitOpError("could not infer buffer type of block argument"); + + return bufferType; + } + +protected: + /// Assuming that `bbArg` is a block argument of a block that belongs to the + /// given `op`, return all OpOperands of users of this block that are + /// aliasing with the given block argument. + AliasingOpOperandList + getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, + const AnalysisState &state) const { + assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg"); + + // Gather aliasing OpOperands of all operations (callers) that link to + // this block. + AliasingOpOperandList result; + for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg)) + result.addAlias( + {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false}); + + return result; + } +}; + +/// A template that provides a default implementation of `getAliasingValues` +/// for ops that implement the `BranchOpInterface`. 
+template +struct BranchOpBufferizableOpInterfaceExternalModel + : public BufferizableOpInterface::ExternalModel { + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + AliasingValueList result; + auto branchOp = cast(op); + auto operandNumber = opOperand.getOperandNumber(); + + // Gather aliasing block arguments of blocks to which this op may branch to. + for (const auto &it : llvm::enumerate(op->getSuccessors())) { + Block *block = it.value(); + SuccessorOperands operands = branchOp.getSuccessorOperands(it.index()); + assert(operands.getProducedOperandCount() == 0 && + "produced operands not supported"); + if (operands.getForwardedOperands().empty()) + continue; + // The first and last operands that are forwarded to this successor. + int64_t firstOperandIndex = + operands.getForwardedOperands().getBeginOperandIndex(); + int64_t lastOperandIndex = + firstOperandIndex + operands.getForwardedOperands().size(); + bool matchingDestination = operandNumber >= firstOperandIndex && + operandNumber < lastOperandIndex; + // A branch op may have multiple successors. Find the ones that correspond + // to this OpOperand. (There is usually only one.) + if (!matchingDestination) + continue; + // Compute the matching block argument of the destination block. 
+ BlockArgument bbArg = + block->getArgument(operandNumber - firstOperandIndex); + result.addAlias( + {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false}); + } + + return result; + } +}; + +} // namespace bufferization +} // namespace mlir + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -78,8 +78,15 @@ const OpFilter *opFilter = nullptr, BufferizationStatistics *statistics = nullptr); -/// Bufferize the signature of `block`. All block argument types are changed to -/// memref types. +/// Bufferize the signature of `block` and its callers (i.e., ops that have the +/// given block as a successor). All block argument types are changed to memref +/// types. All corresponding operands of all callers are wrapped in +/// bufferization.to_memref ops. All uses of bufferized tensor block arguments +/// are wrapped in bufferization.to_tensor ops. +/// +/// It is expected that all callers implement the `BranchOpInterface`. +/// Otherwise, this function will fail. The `BranchOpInterface` is used to query +/// the range of operands that are forwarded to this block. /// /// It is expected that the parent op of this block implements the /// `BufferizableOpInterface`. The buffer types of tensor block arguments are diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. 
of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace cf { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace cf +} // namespace mlir + +#endif // MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -29,6 +29,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -135,6 +136,7 @@ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); builtin::registerCastOpInterfaceExternalModels(registry); + cf::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); linalg::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -234,6 +234,7 @@ }); if 
(aliasingValues.getNumAliases() == 1 && + isa(aliasingValues.getAliases()[0].value) && !state.bufferizesToMemoryWrite(opOperand) && state.getAliasingOpOperands(aliasingValues.getAliases()[0].value) .getNumAliases() == 1 && @@ -498,11 +499,16 @@ bool AnalysisState::isValueRead(Value value) const { assert(llvm::isa(value.getType()) && "expected TensorType"); SmallVector workingSet; + DenseSet visited; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); + if (visited.contains(uMaybeReading)) + continue; + visited.insert(uMaybeReading); + // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) for (AliasingValue alias : getAliasingValues(*uMaybeReading)) @@ -522,11 +528,21 @@ llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, TraversalConfig config) const { + llvm::DenseSet visited; llvm::SetVector result, workingSet; workingSet.insert(value); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); + + if (!config.revisitAlreadyVisitedValues && visited.contains(value)) { + // Stop traversal if value was already visited. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } + visited.insert(value); + if (condition(value)) { result.insert(value); continue; @@ -659,11 +675,15 @@ // preceding value, so we can follow SSA use-def chains and do a simple // analysis. 
SmallVector worklist; + DenseSet visited; for (OpOperand &use : tensor.getUses()) worklist.push_back(&use); while (!worklist.empty()) { OpOperand *operand = worklist.pop_back_val(); + if (visited.contains(operand)) + continue; + visited.insert(operand); Operation *op = operand->getOwner(); // If the op is not bufferizable, we can safely assume that the value is not diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -3,6 +3,7 @@ BufferizableOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp + UnstructuredControlFlow.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp b/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp @@ -0,0 +1,32 @@ +//===- UnstructuredControlFlow.cpp - Op Interface Helpers ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" + +using namespace mlir; + +SmallVector +mlir::bufferization::detail::getCallerOpOperands(BlockArgument bbArg) { + SmallVector result; + Block *block = bbArg.getOwner(); + for (Operation *caller : block->getUsers()) { + auto branchOp = dyn_cast(caller); + assert(branchOp && "expected that all callers implement BranchOpInterface"); + auto it = llvm::find(caller->getSuccessors(), block); + assert(it != caller->getSuccessors().end() && "could not find successor"); + int64_t successorIdx = std::distance(caller->getSuccessors().begin(), it); + SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx); + assert(operands.getProducedOperandCount() == 0 && + "produced operands not supported"); + int64_t operandIndex = + operands.getForwardedOperands().getBeginOperandIndex() + + bbArg.getArgNumber(); + result.push_back(&caller->getOpOperand(operandIndex)); + } + return result; +} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -357,6 +358,16 @@ /// Return true if the given op has a tensor result or a tensor operand. 
static bool hasTensorSemantics(Operation *op) { + bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) { + return any_of(r.getBlocks(), [](Block &b) { + return any_of(b.getArguments(), [](BlockArgument bbArg) { + return isaTensor(bbArg.getType()); + }); + }); + }); + if (hasTensorBlockArgument) + return true; + if (auto funcOp = dyn_cast(op)) { bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); @@ -618,6 +629,43 @@ } } + + // Bufferize callers of the block. + for (Operation *op : block->getUsers()) { + auto branchOp = dyn_cast(op); + if (!branchOp) + return op->emitOpError("cannot bufferize ops with block references that " + "do not implement BranchOpInterface"); + + auto it = llvm::find(op->getSuccessors(), block); + assert(it != op->getSuccessors().end() && "could not find successor"); + int64_t successorIdx = std::distance(op->getSuccessors().begin(), it); + + SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx); + SmallVector newOperands; + for (auto [operand, type] : + llvm::zip(operands.getForwardedOperands(), newTypes)) { + if (operand.getType() == type) { + // Not a tensor type. Nothing to do for this operand. + newOperands.push_back(operand); + continue; + } + FailureOr operandBufferType = + bufferization::getBufferType(operand, options); + if (failed(operandBufferType)) + return failure(); + rewriter.setInsertionPointAfterValue(operand); + Value bufferizedOperand = rewriter.create( + operand.getLoc(), *operandBufferType, operand); + // A cast is needed if the operand and the block argument have different + // bufferized types. 
+ if (type != *operandBufferType) + bufferizedOperand = rewriter.create( + operand.getLoc(), type, bufferizedOperand); + newOperands.push_back(bufferizedOperand); + } + operands.getMutableForwardedOperands().assign(newOperands); + } + return success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -319,16 +320,45 @@ }; struct FuncOpInterface - : public BufferizableOpInterface::ExternalModel { + : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + FuncOpInterface, FuncOp> { + + static bool supportsUnstructuredControlFlow() { return true; } + + AliasingOpOperandList + getAliasingOpOperands(Operation *op, Value value, + const AnalysisState &state) const { + return getAliasingBranchOpOperands(op, cast(value), state); + } + FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto funcOp = cast(op); auto bbArg = cast(value); - // Unstructured control flow is not supported. - assert(bbArg.getOwner() == &funcOp.getBody().front() && - "expected that block argument belongs to first block"); - return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + + // Function arguments are special. 
+ if (bbArg.getOwner() == &funcOp.getBody().front()) + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), + options); + + return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: + getBufferType(op, value, options, invocationStack); + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + auto funcOp = cast(op); + // TODO: func.func with multiple returns are not supported. + if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal()) + return op->emitOpError("op without unique func.return is not supported"); + const auto &options = + static_cast(state.getOptions()); + // allow-return-allocs is required for ops with multiple blocks. + if (options.allowReturnAllocs || funcOp.getRegion().getBlocks().size() <= 1) + return success(); + return op->emitOpError( + "op cannot be bufferized without allow-return-allocs"); } /// Rewrite function bbArgs and return values into buffer form. This function @@ -358,7 +388,7 @@ // Bodiless functions are assumed opaque and we cannot know the // bufferization contract they want to enforce. As a consequence, only // support functions that don't return any tensors atm. - if (funcOp.getBody().empty()) { + if (funcOp.isExternal()) { SmallVector retTypes; for (Type resultType : funcType.getResults()) { if (isa(resultType)) @@ -420,6 +450,11 @@ BlockArgument bbArg = dyn_cast(value); assert(bbArg && "expected BlockArgument"); + // Non-entry block arguments are always writable. (They may alias with + // values that are not writable, which will turn them into read-only.) + if (bbArg.getOwner() != &funcOp.getBody().front()) + return true; + // "bufferization.writable" overrides other writability decisions. This is // currently used for testing only. 
if (BoolAttr writable = funcOp.getArgAttrOfType( diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -309,8 +309,29 @@ return false; } +static bool isReachable(Block *from, Block *to, ArrayRef except) { + DenseSet visited; + SmallVector worklist; + for (Block *succ : from->getSuccessors()) + worklist.push_back(succ); + while (!worklist.empty()) { + Block *next = worklist.pop_back_val(); + if (llvm::find(except, next) != except.end()) + continue; + if (next == to) + return true; + if (visited.contains(next)) + continue; + visited.insert(next); + for (Block *succ : next->getSuccessors()) + worklist.push_back(succ); + } + return false; +} + /// Return `true` if op dominance can be used to rule out a read-after-write -/// conflicts based on the ordering of ops. +/// conflicts based on the ordering of ops. Returns `false` if op dominance +/// cannot be used to due region-based loops. /// /// Generalized op dominance can often be used to rule out potential conflicts /// due to "read happens before write". E.g., the following IR is not a RaW @@ -383,9 +404,9 @@ /// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.) 
/// -static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, - const SetVector &definitions, - AnalysisState &state) { +static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite, + const SetVector &definitions, + AnalysisState &state) { const BufferizationOptions &options = state.getOptions(); for (Value def : definitions) { Region *rRead = @@ -411,9 +432,53 @@ if (rRead->getParentOp()->isAncestor(uWrite->getOwner())) return false; } + + return true; +} + +/// Return `true` if op dominance can be used to rule out a read-after-write +/// conflict based on the ordering of ops. Returns `false` if op dominance +/// cannot be used due to block-based loops within a region. +/// +/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on +/// how op dominance is used during RaW conflict detection. +/// +/// On a high-level, there is a potential RaW in a program if there exists a +/// possible program execution such that there is a sequence of DEF, followed +/// by WRITE, followed by READ. Each additional DEF resets the sequence. +/// +/// Op dominance cannot be used if there is a path from block(READ) to +/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should +/// not appear on that path. +static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, + const SetVector &definitions, + AnalysisState &state) { + // Fast path: If READ and WRITE are in different regions, their block cannot + // be reachable just via unstructured control flow. (Loops due to regions are + // covered by `canUseOpDominanceDueToRegions`.) 
+ if (uRead->getOwner()->getParentRegion() != + uWrite->getOwner()->getParentRegion()) + return true; + + Block *readBlock = uRead->getOwner()->getBlock(); + Block *writeBlock = uWrite->getOwner()->getBlock(); + for (Value def : definitions) { + Block *defBlock = def.getParentBlock(); + if (isReachable(readBlock, writeBlock, {defBlock}) && + isReachable(writeBlock, readBlock, {defBlock})) + return false; + } + return true; } +static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, + const SetVector &definitions, + AnalysisState &state) { + return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) && + canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state); +} + /// Annotate IR with details about the detected RaW conflict. static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value definition) { diff --git a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt --- a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,75 @@ +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace cf { +namespace { + +template +struct BranchLikeOpInterface + : public BranchOpBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + const auto &options = + static_cast(state.getOptions()); + if (options.allowReturnAllocs) + return success(); + return op->emitOpError( + "op cannot be bufferized without allow-return-allocs"); + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + // The operands of this op are bufferized together with the block signature. + return success(); + } +}; + +/// Bufferization of cf.br. +struct BranchOpInterface + : public BranchLikeOpInterface {}; + +/// Bufferization of cf.cond_br. 
+struct CondBranchOpInterface + : public BranchLikeOpInterface<CondBranchOpInterface, cf::CondBranchOp> {}; + +} // namespace +} // namespace cf +} // namespace mlir + +void mlir::cf::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { + cf::BranchOp::attachInterface<BranchOpInterface>(*ctx); + cf::CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx); + }); +} diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRControlFlowTransforms + BufferizableOpInterfaceImpl.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms + + LINK_LIBS PUBLIC + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRControlFlowDialect + MLIRMemRefDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -10,6 +10,8 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -99,37 +101,74 @@ } }; +/// Return the unique scf.yield op. If there are multiple or no scf.yield ops, +/// return an empty op. 
+static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { + scf::YieldOp result; + for (Block &block : executeRegionOp.getRegion()) { + if (auto yieldOp = dyn_cast(block.getTerminator())) { + if (result) + return {}; + result = yieldOp; + } + } + return result; +} + /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not /// fully implemented at the moment. struct ExecuteRegionOpInterface - : public BufferizableOpInterface::ExternalModel { + : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< + ExecuteRegionOpInterface, scf::ExecuteRegionOp> { + + static bool supportsUnstructuredControlFlow() { return true; } + + bool isWritable(Operation *op, Value value, + const AnalysisState &state) const { + return true; + } + + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + auto executeRegionOp = cast(op); + // TODO: scf.execute_region with multiple yields are not supported. + if (!getUniqueYieldOp(executeRegionOp)) + return op->emitOpError("op without unique scf.yield is not supported"); + const auto &options = + static_cast(state.getOptions()); + // allow-return-allocs is required for ops with multiple blocks. + if (options.allowReturnAllocs || + executeRegionOp.getRegion().getBlocks().size() == 1) + return success(); + return op->emitOpError( + "op cannot be bufferized without allow-return-allocs"); + } + AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const { + if (auto bbArg = dyn_cast(value)) + return getAliasingBranchOpOperands(op, bbArg, state); + // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value // is considered to be aliasing with the result. 
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-    size_t resultNum = std::distance(op->getOpResults().begin(),
-                                     llvm::find(op->getOpResults(), value));
-    // TODO: Support multiple blocks.
-    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
-           "expected exactly 1 block");
-    auto yieldOp = dyn_cast<scf::YieldOp>(
-        executeRegionOp.getRegion().front().getTerminator());
-    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
+    auto it = llvm::find(op->getOpResults(), value);
+    assert(it != op->getOpResults().end() && "invalid value");
+    size_t resultNum = std::distance(op->getOpResults().begin(), it);
+    auto yieldOp = getUniqueYieldOp(executeRegionOp);
+    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
+    if (!yieldOp)
+      return {};
     return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
   }

   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    auto yieldOp =
-        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
+    auto yieldOp = getUniqueYieldOp(executeRegionOp);
     TypeRange newResultTypes(yieldOp.getResults());

     // Create new op and move over region.
@@ -137,6 +176,12 @@
         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
     newOp.getRegion().takeBody(executeRegionOp.getRegion());

+    // Bufferize every block.
+    for (Block &block : newOp.getRegion())
+      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
+                                                        options)))
+        return failure();
+
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -321,3 +321,12 @@
 }
   return
 }
+
+// -----
+
+// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
+  func.return %t : tensor<5xf32>
+^bb1(%arg1 : tensor<5xf32>):
+  func.return %arg1 : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -659,3 +659,15 @@

   return %r1 : vector<5xf32>
 }
+
+// -----
+
+// Note: The cf.br canonicalizes away, so there's nothing to check here. There
+// is a detailed test in ControlFlow/bufferize.mlir.
+ +// CHECK-LABEL: func @br_in_func( +func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> { + cf.br ^bb1(%t : tensor<5xf32>) +^bb1(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +} diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir @@ -0,0 +1,221 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs test-analysis-only dump-alias-sets bufferize-function-boundaries" -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @single_branch( +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[t:.*]]"]], [{{\[}}"%[[arg1]]", "%[[t]]"]]]]} +func.func @single_branch(%t: tensor<5xf32>) -> tensor<5xf32> { +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb1(%t : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>) +^bb1(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: func @diamond_branch( +// CHECK-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {{.*}}, %[[t1:.*]]: tensor<5xf32> {{.*}}) -> tensor<5xf32> +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[arg3:.*]]", "%[[arg2:.*]]", "%[[t0]]", "%[[t1]]"], [ +func.func @diamond_branch(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> { +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %c, ^bb1(%t0 : tensor<5xf32>), ^bb2(%t1 : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>): +^bb3(%arg1 : tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +// CHECK: ^{{.*}}(%[[arg2]]: tensor<5xf32>): +^bb1(%arg2 : tensor<5xf32>): +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb3(%arg2 : tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg3]]: tensor<5xf32>): +^bb2(%arg3 : tensor<5xf32>): 
+// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb3(%arg3 : tensor<5xf32>) +} + +// ----- + +// CHECK-LABEL: func @looping_branches( +// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[], [{{\[}}"%[[arg2:.*]]", "%[[arg1:.*]]", "%[[inserted:.*]]", "%[[empty:.*]]"]], [ +func.func @looping_branches() -> tensor<5xf32> { +// CHECK: %[[empty]] = tensor.empty() + %0 = tensor.empty() : tensor<5xf32> +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["true"]} + cf.br ^bb1(%0: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>): +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: %[[inserted]] = tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +^bb2(%arg2: tensor<5xf32>): + func.return %arg2 : tensor<5xf32> +} + +// ----- + +// CHECK-LABEL: func @looping_branches_with_conflict( +func.func @looping_branches_with_conflict(%f: f32) -> tensor<5xf32> { + %0 = tensor.empty() : tensor<5xf32> + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> +// CHECK: cf.br +// CHECK-SAME: {__inplace_operands_attr__ = ["false"]} + cf.br ^bb1(%filled: tensor<5xf32>) +^bb2(%arg2: tensor<5xf32>): + %pos2 = "test.foo"() : () -> (index) + // One OpOperand cannot bufferize in-place because an "old" value is read. 
+ %element = tensor.extract %filled[%pos2] : tensor<5xf32> + func.return %arg2 : tensor<5xf32> +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +} + +// ----- + +// CHECK-LABEL: func @looping_branches_outside_def( +func.func @looping_branches_outside_def(%f: f32) { +// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: %[[fill:.*]] = linalg.fill +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[fill]]", "%[[alloc]]"]]} + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + cf.br ^bb1 +^bb1: + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"] + %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32> + %pos2 = "test.foo"() : () -> (index) + %read = tensor.extract %inserted[%pos2] : tensor<5xf32> + %cond = "test.qux"(%read) : (f32) -> (i1) + cf.cond_br %cond, ^bb1, ^bb2 +^bb2: + func.return +} + +// ----- + +// CHECK-LABEL: func @looping_branches_outside_def2( +func.func @looping_branches_outside_def2(%f: f32) { +// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: %[[fill:.*]] = linalg.fill +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[arg0:.*]]", "%[[fill]]", "%[[alloc]]"]]} + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> +// 
CHECK: cf.br {{.*}}(%[[fill]] : tensor<5xf32>) +// CHECK-SAME: __inplace_operands_attr__ = ["true"] + cf.br ^bb1(%filled: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg0]]: tensor<5xf32>): +^bb1(%arg0: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"] + %inserted = tensor.insert %val into %arg0[%pos] : tensor<5xf32> + %pos2 = "test.foo"() : () -> (index) + %read = tensor.extract %inserted[%pos2] : tensor<5xf32> + %cond = "test.qux"(%read) : (f32) -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true"] + cf.cond_br %cond, ^bb1(%arg0: tensor<5xf32>), ^bb2 +^bb2: + func.return +} + +// ----- + +// CHECK-LABEL: func @looping_branches_outside_def3( +func.func @looping_branches_outside_def3(%f: f32) { +// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: %[[fill:.*]] = linalg.fill +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[arg0:.*]]", "%[[fill]]", "%[[alloc]]"]]} + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> +// CHECK: cf.br {{.*}}(%[[fill]] : tensor<5xf32>) +// CHECK-SAME: __inplace_operands_attr__ = ["true"] + cf.br ^bb1(%filled: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg0]]: tensor<5xf32>): +^bb1(%arg0: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"] + %inserted = tensor.insert %val into %arg0[%pos] : tensor<5xf32> + %pos2 = "test.foo"() : () -> (index) + %read = tensor.extract %inserted[%pos2] : tensor<5xf32> + %cond = "test.qux"(%read) : (f32) -> (i1) +// CHECK: cf.cond_br +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true"] + cf.cond_br %cond, ^bb1(%filled: tensor<5xf32>), ^bb2 +^bb2: + func.return +} + +// 
----- + +// CHECK-LABEL: func @looping_branches_sequence_outside_def( +func.func @looping_branches_sequence_outside_def(%f: f32) { +// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: %[[fill:.*]] = linalg.fill +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[fill]]", "%[[alloc]]"]]} + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + cf.br ^bb1 +^bb1: + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"] + %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32> + cf.br ^bb2 +^bb2: + %pos2 = "test.foo"() : () -> (index) + %read = tensor.extract %inserted[%pos2] : tensor<5xf32> + %cond = "test.qux"(%read) : (f32) -> (i1) + cf.cond_br %cond, ^bb1, ^bb3 +^bb3: + func.return +} + +// ----- + +// CHECK-LABEL: func @looping_branches_sequence_inside_def( +func.func @looping_branches_sequence_inside_def(%f: f32) { + cf.br ^bb1 +^bb1: +// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: %[[fill:.*]] = linalg.fill +// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[inserted:.*]]", "%[[fill]]", "%[[alloc]]"]]} + %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: %[[inserted]] = tensor.insert +// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32> + cf.br ^bb2 +^bb2: + %pos2 = "test.foo"() : () -> (index) + %read = tensor.extract %inserted[%pos2] : tensor<5xf32> + %cond = "test.qux"(%read) : (f32) -> (i1) + cf.cond_br %cond, ^bb1, ^bb3 +^bb3: + func.return +} diff --git 
a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s -verify-diagnostics + +// expected-error @below{{failed to bufferize op}} +// expected-error @below{{incoming operands of block argument have inconsistent memory spaces}} +func.func @inconsistent_memory_space() -> tensor<5xf32> { + %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<5xf32> + cf.br ^bb1(%0: tensor<5xf32>) +^bb1(%arg1: tensor<5xf32>): + func.return %arg1 : tensor<5xf32> +^bb2(): + %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<5xf32> + cf.br ^bb1(%1: tensor<5xf32>) +} + +// ----- + +// expected-error @below{{failed to bufferize op}} +// expected-error @below{{could not infer buffer type of block argument}} +func.func @cannot_infer_type() { + return + // The type of the block argument cannot be inferred. 
+^bb1(%t: tensor<5xf32>): + cf.br ^bb1(%t: tensor<5xf32>) +} diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s | FileCheck %s +// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs" -split-input-file %s | FileCheck %s --check-prefix=CHECK-NO-FUNC + +// CHECK-NO-FUNC-LABEL: func @br( +// CHECK-NO-FUNC-SAME: %[[t:.*]]: tensor<5xf32>) +// CHECK-NO-FUNC: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>> +// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { +// CHECK-NO-FUNC: cf.br ^[[block:.*]](%[[m]] +// CHECK-NO-FUNC: ^[[block]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>): +// CHECK-NO-FUNC: scf.yield %[[arg1]] +// CHECK-NO-FUNC: } +// CHECK-NO-FUNC: return +func.func @br(%t: tensor<5xf32>) { + %0 = scf.execute_region -> tensor<5xf32> { + cf.br ^bb1(%t : tensor<5xf32>) + ^bb1(%arg1 : tensor<5xf32>): + scf.yield %arg1 : tensor<5xf32> + } + return +} + +// ----- + +// CHECK-NO-FUNC-LABEL: func @cond_br( +// CHECK-NO-FUNC-SAME: %[[t1:.*]]: tensor<5xf32>, +// CHECK-NO-FUNC: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32, strided<[?], offset: ?>> +// CHECK-NO-FUNC: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32> +// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { +// CHECK-NO-FUNC: cf.cond_br %{{.*}}, ^[[block1:.*]](%[[m1]] : {{.*}}), ^[[block2:.*]](%[[alloc]] : {{.*}}) +// CHECK-NO-FUNC: ^[[block1]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>): +// CHECK-NO-FUNC: scf.yield %[[arg1]] +// CHECK-NO-FUNC: ^[[block2]](%[[arg2:.*]]: memref<5xf32>): +// CHECK-NO-FUNC: %[[cast:.*]] = memref.cast %[[arg2]] : memref<5xf32> to 
memref<5xf32, strided<[?], offset: ?> +// CHECK-NO-FUNC: cf.br ^[[block1]](%[[cast]] : {{.*}}) +// CHECK-NO-FUNC: } +// CHECK-NO-FUNC: return +func.func @cond_br(%t1: tensor<5xf32>, %c: i1) { + // Use an alloc for the second block instead of a function block argument. + // A cast must be inserted because the two will have different layout maps. + %t0 = bufferization.alloc_tensor() : tensor<5xf32> + %0 = scf.execute_region -> tensor<5xf32> { + cf.cond_br %c, ^bb1(%t1 : tensor<5xf32>), ^bb2(%t0 : tensor<5xf32>) + ^bb1(%arg1 : tensor<5xf32>): + scf.yield %arg1 : tensor<5xf32> + ^bb2(%arg2 : tensor<5xf32>): + cf.br ^bb1(%arg2 : tensor<5xf32>) + } + return +} + +// ----- + +// CHECK-LABEL: func @looping_branches( +func.func @looping_branches() -> tensor<5xf32> { +// CHECK: %[[alloc:.*]] = memref.alloc + %0 = bufferization.alloc_tensor() : tensor<5xf32> +// CHECK: cf.br {{.*}}(%[[alloc]] + cf.br ^bb1(%0: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg1:.*]]: memref<5xf32>): +^bb1(%arg1: tensor<5xf32>): + %pos = "test.foo"() : () -> (index) + %val = "test.bar"() : () -> (f32) +// CHECK: memref.store %{{.*}}, %[[arg1]] + %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32> + %cond = "test.qux"() : () -> (i1) +// CHECK: cf.cond_br {{.*}}(%[[arg1]] {{.*}}(%[[arg1]] + cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>) +// CHECK: ^{{.*}}(%[[arg2:.*]]: memref<5xf32>): +^bb2(%arg2: tensor<5xf32>): +// CHECK: return %[[arg2]] + func.return %arg2 : tensor<5xf32> +} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir @@ -17,10 +17,10 @@ // ----- -func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> { - // expected-error @below{{op or BufferizableOpInterface implementation does not support unstructured control flow, but at least one 
region has multiple blocks}} +func.func @execute_region_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> { + // expected-error @below{{op op without unique scf.yield is not supported}} %0 = scf.execute_region -> tensor<5xf32> { - cf.br ^bb1(%t : tensor<5xf32>) + scf.yield %t : tensor<5xf32> ^bb1(%arg1 : tensor<5xf32>): scf.yield %arg1 : tensor<5xf32> } @@ -29,6 +29,20 @@ // ----- +func.func @execute_region_no_yield(%t: tensor<5xf32>) -> tensor<5xf32> { + // expected-error @below{{op op without unique scf.yield is not supported}} + %0 = scf.execute_region -> tensor<5xf32> { + cf.br ^bb0(%t : tensor<5xf32>) + ^bb0(%arg0 : tensor<5xf32>): + cf.br ^bb1(%arg0: tensor<5xf32>) + ^bb1(%arg1 : tensor<5xf32>): + cf.br ^bb0(%arg1: tensor<5xf32>) + } + func.return %0 : tensor<5xf32> +} + +// ----- + func.func @inconsistent_memory_space_scf_for(%lb: index, %ub: index, %step: index) -> tensor<10xf32> { %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32> %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32> diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4022,6 +4022,24 @@ ], ) +cc_library( + name = "ControlFlowTransforms", + srcs = glob([ + "lib/Dialect/ControlFlow/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/mlir/Dialect/ControlFlow/Transforms/*.h", + ]), + includes = ["include"], + deps = [ + ":BufferizationDialect", + ":BufferizationTransforms", + ":ControlFlowDialect", + ":IR", + ":MemRefDialect", + ], +) + cc_library( name = "FuncDialect", srcs = glob( @@ -8203,6 +8221,7 @@ ":ComplexToLibm", ":ComplexToSPIRV", ":ControlFlowDialect", + ":ControlFlowTransforms", ":ConversionPasses", ":ConvertToLLVM", ":DLTIDialect", @@ -11907,11 +11926,13 @@ "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", 
"lib/Dialect/Bufferization/IR/BufferizationDialect.cpp", "lib/Dialect/Bufferization/IR/BufferizationOps.cpp", + "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp", ], hdrs = [ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", "include/mlir/Dialect/Bufferization/IR/Bufferization.h", "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h", + "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h", ], includes = ["include"], deps = [