diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" @@ -48,7 +49,10 @@ } def SequenceOp : TransformDialectOp<"sequence", - [DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transform) 
+add_public_tablegen_target(MLIRTransformDialectTransformsIncGen)
+
+add_mlir_doc(Passes TransformPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
@@ -0,0 +1,26 @@
+//===- Passes.h - Transform dialect passes ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace transform {
+/// Creates a pass that warns about Transform IR values that may be used after
+/// having been freed (consumed) by another Transform op.
+std::unique_ptr<Pass> createCheckUsesPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
@@ -0,0 +1,36 @@
+//===-- Passes.td - Transform dialect pass definitions -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def CheckUses : Pass<"transform-dialect-check-uses"> {
+  let summary = "warn about potential use-after-free in the transform dialect";
+  let description = [{
+    This pass analyzes operations from the transform dialect and its extensions
+    and warns if a transform IR value may be used by an operation after it was
+    "freed" by some other operation, as described by side effects on the
+    `TransformMappingResource`. This statically detects situations that lead to
+    errors when interpreting the Transform IR.
+
+    The pass is capable of handling branching control flow and reports all
+    _potential_ use-after-free situations, e.g., a may-use-after-free is
+    reported if at least one of the control flow paths between the definition of
+    a value and its use contains an operation with a "free" effect on the
+    `TransformMappingResource`. It does not currently perform an SCCP-style data
+    flow analysis to prove that some branches are not taken, however, SCCP and
+    other control flow simplifications can be performed on the transform IR
+    prior to this pass provided that transform ops implement the relevant
+    control flow interfaces.
+  }];
+  let constructor = "::mlir::transform::createCheckUsesPass()";
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -32,6 +32,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Transforms/Passes.h"

@@ -72,6 +73,7 @@
   spirv::registerSPIRVPasses();
   tensor::registerTensorPasses();
   tosa::registerTosaOptPasses();
+  transform::registerTransformPasses();
   vector::registerVectorPasses();

   // Dialect pipelines
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -289,6 +290,42 @@
   }
 }

+/// The root operand, if present, is forwarded to the entry block of the body
+/// region; otherwise no operands are forwarded.
+OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index == 0 && "unexpected region index");
+  if (getOperation()->getNumOperands() == 1)
+    return getOperation()->getOperands();
+  return OperandRange(getOperation()->operand_end(),
+                      getOperation()->operand_end());
+}
+
+/// Control flows from the parent op into the body region, and from the body
+/// region back to the parent op, whose results receive the yielded values.
+void transform::SequenceOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> /*operands*/,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!index.hasValue()) {
+    // Derive forwarding from the op itself rather than from the constant
+    // operand list so that this always agrees with getSuccessorEntryOperands.
+    Region *bodyRegion = &getBody();
+    regions.emplace_back(bodyRegion, getOperation()->getNumOperands() == 1
+                                         ? bodyRegion->getArguments()
+                                         : Block::BlockArgListType());
+    return;
+  }
+
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(getOperation()->getResults());
+}
+
+/// The body region is always executed exactly once.
+void transform::SequenceOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> /*operands*/,
+    SmallVectorImpl<InvocationBounds> &bounds) {
+  bounds.emplace_back(1, 1);
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRTransformDialectTransforms
+  CheckUses.cpp
+
+  DEPENDS
+  MLIRTransformDialectTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRTransformDialect
+  MLIRIR
+  MLIRPass
+)
diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
@@ -0,0 +1,402 @@
+//===- CheckUses.cpp - Expensive transform value validity checks ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a pass that performs expensive opt-in checks for Transform
+// dialect values being potentially used after they have been consumed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns a reference to a cached set of blocks that are reachable from the
+/// given block via edges computed by the `getNextNodes` function. For example,
+/// if `getNextNodes` returns successors of a block, this will return the set of
+/// reachable blocks; if it returns predecessors of a block, this will return
+/// the set of blocks from which the given block can be reached. The block is
+/// considered reachable from itself only if there is a cycle.
+template <typename FnTy>
+const llvm::SmallPtrSet<Block *, 4> &
+getReachableImpl(Block *block, FnTy getNextNodes,
+                 DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
+  auto it = cache.find(block);
+  if (it != cache.end())
+    return it->getSecond();
+
+  llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
+  SmallVector<Block *> worklist;
+  worklist.push_back(block);
+  while (!worklist.empty()) {
+    Block *current = worklist.pop_back_val();
+    for (Block *next : getNextNodes(current)) {
+      // Every block reached via `getNextNodes` is transitively reachable. Only
+      // add it to the worklist if it wasn't already visited.
+      if (reachable.insert(next).second)
+        worklist.push_back(next);
+    }
+  }
+  return reachable;
+}
+
+/// An analysis that identifies whether a value allocated by a Transform op may
+/// be used by another such op after it may have been freed by a third op on
+/// some control flow path. This is conceptually similar to a data flow
+/// analysis, but relies on side effects related to particular values that
+/// currently cannot be modeled by the MLIR data flow analysis framework (also,
+/// the lattice element would be rather expensive as it would need to include
+/// live and/or freed values for each operation).
+///
+/// This analysis is conservatively pessimistic: it will consider that a value
+/// may be freed if it is freed on any possible control flow path between its
+/// allocation and a relevant use, even if the control never actually flows
+/// through the operation that frees the value. It also does not differentiate
+/// between may- (freed on at least one control flow path) and must-free (freed
+/// on all possible control flow paths) because it would require expensive graph
+/// algorithms.
+///
+/// It is intended as an additional non-blocking verification or debugging aid
+/// for ops in the Transform dialect. It leverages the requirement for Transform
+/// dialect ops to implement the MemoryEffectsOpInterface, and expects the
+/// values in the Transform IR to have an allocation effect on the
+/// TransformMappingResource when defined.
+class TransformOpMemFreeAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
+
+  /// Computes the analysis for the outermost Transform ops nested in `root`.
+  explicit TransformOpMemFreeAnalysis(Operation *root) {
+    root->walk<WalkOrder::PreOrder>([&](Operation *op) {
+      if (isa<transform::TransformOpInterface>(op)) {
+        collectFreedValues(op);
+        return WalkResult::skip();
+      }
+      return WalkResult::advance();
+    });
+  }
+
+  /// A list of operations that may be deleting a value. Non-empty list
+  /// contextually converts to boolean "true" value.
+  class PotentialDeleters {
+  public:
+    /// Creates an empty list that corresponds to the value being live.
+    static PotentialDeleters live() { return PotentialDeleters({}); }
+
+    /// Creates a list from the operations that may be deleting the value.
+    static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+      return PotentialDeleters(deleters);
+    }
+
+    /// Converts to "true" if there are operations that may be deleting the
+    /// value.
+    explicit operator bool() const { return !deleters.empty(); }
+
+    /// Concatenates the lists of operations that may be deleting the value. The
+    /// value is known to be live if the resulting list is still empty.
+    PotentialDeleters &operator|=(const PotentialDeleters &other) {
+      llvm::append_range(deleters, other.deleters);
+      return *this;
+    }
+
+    /// Returns the list of ops that may be deleting the value.
+    ArrayRef<Operation *> getOps() const { return deleters; }
+
+  private:
+    /// Constructs the list from the given operations.
+    explicit PotentialDeleters(ArrayRef<Operation *> ops) {
+      llvm::append_range(deleters, ops);
+    }
+
+    /// The list of operations that may be deleting the value.
+    SmallVector<Operation *> deleters;
+  };
+
+  /// Returns the list of operations that may be deleting the operand value on
+  /// any control flow path between the definition of the value and its use as
+  /// the given operand. For the purposes of this analysis, the value is
+  /// considered to be allocated at its definition point and never re-allocated.
+  PotentialDeleters isUseLive(OpOperand &operand) {
+    auto freedIt = freedBy.find(operand.get());
+    if (freedIt == freedBy.end() || freedIt->getSecond().empty())
+      return live();
+
+#ifndef NDEBUG
+    // Check that the definition point actually allocates the value.
+    Operation *valueSource =
+        operand.get().isa<OpResult>()
+            ? operand.get().getDefiningOp()
+            : operand.get().getParentBlock()->getParentOp();
+    auto iface = cast<MemoryEffectOpInterface>(valueSource);
+    SmallVector<MemoryEffects::EffectInstance> instances;
+    iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+                               instances);
+    assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
+           "expected the op defining the value to have an allocation effect "
+           "on it");
+#endif
+
+    // Collect ancestors of the use operation.
+    Block *defBlock = operand.get().getParentBlock();
+    SmallVector<Operation *> ancestors;
+    Operation *current = operand.getOwner();
+    do {
+      ancestors.push_back(current);
+      if (current->getParentRegion() == defBlock->getParent())
+        break;
+      current = current->getParentOp();
+    } while (true);
+    std::reverse(ancestors.begin(), ancestors.end());
+
+    // Consider the control flow from the definition point of the value to its
+    // use point. If the use is located in some nested region, consider the path
+    // from the entry block of the region to the use.
+    for (Operation *ancestor : ancestors) {
+      // The block should be considered partially if it is the block that
+      // contains the definition (allocation) of the value being used, and the
+      // value is defined in the middle of the block, i.e., is not a block
+      // argument.
+      bool isOutermost = ancestor == ancestors.front();
+      bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
+
+      // Check if the value may be freed by operations between its definition
+      // (allocation) point in its block and the terminator of the block or the
+      // ancestor of the use if it is located in the same block. This is only
+      // done for partial blocks here, full blocks will be considered below
+      // similarly to other blocks.
+      if (isFromBlockPartial) {
+        bool defUseSameBlock = ancestor->getBlock() == defBlock;
+        // Consider all ops from the def to its block terminator, except when
+        // the use is in the same block, in which case only consider the ops
+        // until the user.
+        if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
+                operand.get().getDefiningOp(), operand.get(),
+                defUseSameBlock ? ancestor : nullptr))
+          return potentialDeleters;
+      }
+
+      // Check if the value may be freed by operations preceding the ancestor in
+      // its block. Skip the check for partial blocks that contain both the
+      // definition and the use point, as this has been already checked above.
+      if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
+        if (PotentialDeleters potentialDeleters =
+                isFreedInBlockBefore(ancestor, operand.get()))
+          return potentialDeleters;
+      }
+
+      // Check if the value may be freed by operations in any of the blocks
+      // between the definition point (in the outermost region) or the entry
+      // block of the region (in other regions) and the operand or its ancestor
+      // in the region. This includes the entire "from" block if (1) the block
+      // has not been considered as partial above or (2) the block can be
+      // reached again through some control-flow loop. This includes the entire
+      // "to" block if it can be reached from itself through some control-flow
+      // cycle, regardless of whether it has been visited before.
+      Block *ancestorBlock = ancestor->getBlock();
+      Block *from =
+          isOutermost ? defBlock : &ancestorBlock->getParent()->front();
+      if (PotentialDeleters potentialDeleters =
+              isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
+                                  /*alwaysIncludeFrom=*/!isFromBlockPartial))
+        return potentialDeleters;
+    }
+    return live();
+  }
+
+private:
+  /// Make PotentialDeleters constructors available with shorter names.
+  static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+    return PotentialDeleters::maybeFreed(deleters);
+  }
+  static PotentialDeleters live() { return PotentialDeleters::live(); }
+
+  /// Returns the first operation that may be deleting the given value strictly
+  /// between `first` and `last`, or `live()` if there is none. A null `last`
+  /// walks to the block boundary. `getNext` gives the traversal direction.
+  PotentialDeleters
+  isFreedBetween(Value value, Operation *first, Operation *last,
+                 llvm::function_ref<Operation *(Operation *)> getNext) const {
+    auto it = freedBy.find(value);
+    if (it == freedBy.end())
+      return live();
+    const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
+    for (Operation *op = getNext(first); op != last; op = getNext(op)) {
+      if (deleters.contains(op))
+        return maybeFreed(op);
+    }
+    return live();
+  }
+
+  /// Returns the list of operations that may be deleting the given value
+  /// between the `root` and `before` operations. `root` is expected to be in
+  /// the same block as `before` and precede it. If `before` is null, consider
+  /// all operations until the end of the block including the terminator.
+  PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
+                                        Operation *before = nullptr) const {
+    return isFreedBetween(value, root, before,
+                          [](Operation *op) { return op->getNextNode(); });
+  }
+
+  /// Returns the list of operations that may be deleting the given value
+  /// between the entry of the block and the `root` operation.
+  PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
+    return isFreedBetween(value, root, nullptr,
+                          [](Operation *op) { return op->getPrevNode(); });
+  }
+
+  /// Returns the list of operations that may be deleting the given value on
+  /// any of the control flow paths between the "from" and the "to" block. The
+  /// operations from any block visited on any control flow path are
+  /// considered. The "from" block is considered if there is a control flow
+  /// cycle going through it, i.e., if there is a possibility that all
+  /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
+  /// set. The "to" block is considered only if there is a control flow cycle
+  /// going through it.
+  PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
+                                        bool alwaysIncludeFrom) {
+    // Find all blocks that lie on any path between "from" and "to", i.e., the
+    // intersection of blocks reachable from "from" and blocks from which "to"
+    // is reachable.
+    auto freedIt = freedBy.find(value);
+    const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
+    if (freedIt == freedBy.end() || !sources.contains(from))
+      return live();
+    llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
+    llvm::set_intersect(reachable, sources);
+
+    // If requested, include the "from" block that may not be present in the set
+    // of visited blocks when there is no cycle going through it.
+    if (alwaysIncludeFrom)
+      reachable.insert(from);
+
+    // Join potential deleters from all blocks, visited in region order so that
+    // diagnostics are deterministic, as we don't know which path is taken.
+    PotentialDeleters potentialDeleters = live();
+    for (Block &block : *from->getParent()) {
+      if (reachable.contains(&block))
+        for (Operation &op : block)
+          if (freedIt->getSecond().contains(&op))
+            potentialDeleters |= maybeFreed(&op);
+    }
+    return potentialDeleters;
+  }
+
+  /// Returns the (cached) set of blocks that are reachable from the given
+  /// block. A block is considered reachable from itself if there is a cycle in
+  /// the control-flow graph that involves the block.
+  const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
+    return getReachableImpl(
+        block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
+  }
+
+  /// Returns the (cached) set of blocks from which the given block is
+  /// reachable.
+  const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
+    return getReachableImpl(
+        block, [](Block *b) { return b->getPredecessors(); },
+        reachableFromCache);
+  }
+
+  /// Returns true if `instances` contains an effect of `EffectTy` on `value`.
+  template <typename EffectTy>
+  static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
+                        Value value) {
+    return llvm::any_of(instances,
+                        [&](const MemoryEffects::EffectInstance &instance) {
+                          return instance.getValue() == value &&
+                                 isa<EffectTy>(instance.getEffect());
+                        });
+  }
+
+  /// Records the values that are being freed by an operation or any of its
+  /// children in `freedBy`.
+  void collectFreedValues(Operation *root) {
+    SmallVector<MemoryEffects::EffectInstance> instances;
+    root->walk([&](Operation *child) {
+      // TODO: conservatively treat ops without declared memory effects as maybe
+      // freeing their operands. For now, such ops are simply skipped.
+      auto iface = dyn_cast<MemoryEffectOpInterface>(child);
+      if (!iface)
+        return;
+      instances.clear();
+      iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+                                 instances);
+      for (Value operand : child->getOperands()) {
+        if (hasEffect<MemoryEffects::Free>(instances, operand)) {
+          // All parents of the operation that frees a value should be
+          // considered as potentially freeing the value as well.
+          //
+          // TODO: differentiate between must-free/may-free as well as between
+          // this op having the effect and children having the effect. This may
+          // require some analysis of all control flow paths through the nested
+          // regions as well as a mechanism to separate proper side effects from
+          // those obtained by nesting.
+          for (Operation *parent = child;; parent = parent->getParentOp()) {
+            freedBy[operand].insert(parent);
+            if (parent == root)
+              break;
+          }
+        }
+      }
+    });
+  }
+
+  /// The mapping from a value to operations that have a Free memory effect on
+  /// the TransformMappingResource and associated with this value, or to
+  /// Transform operations transitively containing such operations.
+  DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
+
+  /// Caches for sets of reachable blocks.
+  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
+  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
+};
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+
+/// A simple pass that warns about any use of a value by a transform operation
+/// that may be using the value after it has been freed.
+class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
+public:
+  void runOnOperation() override {
+    auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
+
+    getOperation()->walk([&](Operation *child) {
+      for (OpOperand &operand : child->getOpOperands()) {
+        TransformOpMemFreeAnalysis::PotentialDeleters deleters =
+            analysis.isUseLive(operand);
+        if (!deleters)
+          continue;
+
+        InFlightDiagnostic diag = child->emitWarning()
+                                  << "operand #" << operand.getOperandNumber()
+                                  << " may be used after free";
+        diag.attachNote(operand.get().getLoc()) << "allocated here";
+        for (Operation *deleter : deleters.getOps())
+          diag.attachNote(deleter->getLoc()) << "freed here";
+      }
+    });
+    markAllAnalysesPreserved();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace transform {
+std::unique_ptr<Pass> createCheckUsesPass() {
+  return std::make_unique<CheckUsesPass>();
+}
+} // namespace transform
+} // namespace mlir
diff --git a/mlir/test/Dialect/Transform/check-use-after-free.mlir b/mlir/test/Dialect/Transform/check-use-after-free.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/check-use-after-free.mlir
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s --transform-dialect-check-uses --split-input-file --verify-diagnostics
+
+func.func @use_after_free_branching_control_flow() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // expected-note @below {{freed here}}
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb3:
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 {
+    ^bb0(%arg0: !pdl.operation):
+    }
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free_in_nested_op() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  // expected-note @below {{freed here}}
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb1:
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb3:
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  // expected-warning @below {{operand #0 may be used after free}}
+  transform.sequence %0 {
+  ^bb0(%arg0: !pdl.operation):
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free_recursive_side_effects() {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-note @below {{allocated here}}
+    %0 = transform.sequence %arg0 attributes { ord = 1 } {
+    ^bb1(%arg1: !pdl.operation):
+      yield %arg1 : !pdl.operation
+    } : !pdl.operation
+    transform.sequence %0 attributes { ord = 2 } {
+    ^bb2(%arg2: !pdl.operation):
+    }
+    transform.sequence %0 attributes { ord = 3 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+
+    // `transform.sequence` has recursive side effects so it has the same "free"
+    // as the child op it contains.
+    // expected-note @below {{freed here}}
+    transform.sequence %0 attributes { ord = 4 } {
+    ^bb4(%arg4: !pdl.operation):
+      test_consume_operand_if_matches_param_or_fail %0[42]
+    }
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 attributes { ord = 5 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free() {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-note @below {{allocated here}}
+    %0 = transform.sequence %arg0 attributes { ord = 1 } {
+    ^bb1(%arg1: !pdl.operation):
+      yield %arg1 : !pdl.operation
+    } : !pdl.operation
+    transform.sequence %0 attributes { ord = 2 } {
+    ^bb2(%arg2: !pdl.operation):
+    }
+    transform.sequence %0 attributes { ord = 3 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+
+    // expected-note @below {{freed here}}
+    test_consume_operand_if_matches_param_or_fail %0[42]
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 attributes { ord = 5 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+  }
+  return
+}
+
+// -----
+
+// In the case of a control flow cycle, the operation that uses the value may
+// precede the one that frees it in the same block. Both operations should
+// be reported as use-after-free.
+func.func @use_after_free_self_cycle() { + // expected-note @below {{allocated here}} + %0 = transform.test_produce_param_or_forward_operand 42 + transform.test_transform_op_with_regions { + "transform.test_branching_transform_op_terminator"() : () -> () + }, + { + ^bb0: + "transform.test_branching_transform_op_terminator"()[^bb1] : () -> () + ^bb1: + // expected-warning @below {{operand #0 may be used after free}} + transform.sequence %0 { + ^bb0(%arg0: !pdl.operation): + } + // expected-warning @below {{operand #0 may be used after free}} + // expected-note @below {{freed here}} + transform.test_consume_operand_if_matches_param_or_fail %0[42] + "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () + ^bb2: + "transform.test_branching_transform_op_terminator"() : () -> () + } + return +} + + +// ----- + +// Check that a "free" that happens in a cycle is also reported as a potential +// use-after-free. +func.func @use_after_free_cycle() { + // expected-note @below {{allocated here}} + %0 = transform.test_produce_param_or_forward_operand 42 + transform.test_transform_op_with_regions { + "transform.test_branching_transform_op_terminator"() : () -> () + }, + { + ^bb0: + "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> () + ^bb1: + // expected-warning @below {{operand #0 may be used after free}} + // expected-note @below {{freed here}} + transform.test_consume_operand_if_matches_param_or_fail %0[42] + "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> () + ^bb2: + "transform.test_branching_transform_op_terminator"()[^bb1] : () -> () + ^bb3: + "transform.test_branching_transform_op_terminator"() : () -> () + } + return +} + diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++
b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -184,6 +184,21 @@ state.removeExtension(); return success(); } +LogicalResult mlir::test::TestTransformOpWithRegions::apply( + transform::TransformResults &results, transform::TransformState &state) { + return success(); +} + +void mlir::test::TestTransformOpWithRegions::getEffects( + SmallVectorImpl &effects) {} + +LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply( + transform::TransformResults &results, transform::TransformState &state) { + return success(); +} + +void mlir::test::TestBranchingTransformOpTerminator::getEffects( + SmallVectorImpl &effects) {} namespace { /// Test extension of the Transform dialect. Registers additional ops and diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -92,5 +92,21 @@ let cppNamespace = "::mlir::test"; } +def TestTransformOpWithRegions + : Op, + DeclareOpInterfaceMethods]> { + let regions = (region AnyRegion:$first, AnyRegion:$second); + let assemblyFormat = "attr-dict-with-keyword regions"; + let cppNamespace = "::mlir::test"; +} + +def TestBranchingTransformOpTerminator + : Op, + DeclareOpInterfaceMethods]> { + let successors = (successor VariadicSuccessor:$succ); + let cppNamespace = "::mlir::test"; +} #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6266,6 +6266,7 @@ ":TosaDialect", ":TosaToLinalg", ":TransformDialect", + ":TransformDialectTransforms", ":Transforms", ":TransformsPassIncGen", ":VectorOps", @@ -7876,6 +7877,7 @@ name = "TransformDialectTdFiles", srcs =
glob(["include/mlir/Dialect/Transform/IR/*.td"]), deps = [ + ":ControlFlowInterfacesTdFiles", ":OpBaseTdFiles", ":PDLDialectTdFiles", ":SideEffectInterfacesTdFiles", @@ -7949,6 +7951,7 @@ srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]), deps = [ + ":ControlFlowInterfaces", ":IR", ":PDLDialect", ":PDLInterpDialect", @@ -7962,6 +7965,47 @@ ], ) +td_library( + name = "TransformDialectTransformsTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]), + deps = [ + ":PassBaseTdFiles", + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformDialectTransformsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=Transform", + ], + "include/mlir/Dialect/Transform/Transforms/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/Transforms/Passes.td", + deps = [":TransformDialectTransformsTdFiles"], +) + +cc_library( + name = "TransformDialectTransforms", + srcs = glob(["lib/Dialect/Transform/Transforms/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/Transforms/*.h"]), + deps = [ + ":Analysis", + ":IR", + ":Pass", + ":SideEffectInterfaces", + ":TransformDialect", + ":TransformDialectTransformsIncGen", + "//llvm:Support", + ], +) + td_library( name = "ComplexOpsTdFiles", srcs = [