diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -87,6 +87,10 @@
   ValueRange inputs;
 };
 
+/// Return `true` if `a` and `b` are in mutually exclusive regions as per
+/// RegionBranchOpInterface.
+bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -430,9 +430,8 @@
                                      aliasInfo))
         continue;
 
-      // Special rules for branches.
-      // TODO: Use an interface.
-      if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
+      // Ops are not conflicting if they are in mutually exclusive regions.
+      if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
         continue;
 
       LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,6 +219,99 @@
   return success();
 }
 
+/// Return `true` if region `r` is reachable from region `begin` by following
+/// any number of control flow edges reported by
+/// `RegionBranchOpInterface::getSuccessorRegions` of `branchOp`. A region is
+/// considered reachable from itself. Both regions must belong to `branchOp`.
+///
+/// The traversal is iterative and tracks visited regions, so it terminates
+/// even if the region successor graph contains cycles (e.g., a region that
+/// lists itself as a successor).
+static bool isRegionReachable(RegionBranchOpInterface branchOp, Region *begin,
+                              Region *r) {
+  Operation *op = branchOp.getOperation();
+  assert(begin && begin->getParentOp() == op && "expected region of branchOp");
+  assert(r && r->getParentOp() == op && "expected region of branchOp");
+  if (begin == r)
+    return true;
+
+  SmallVector<bool, 8> visited(op->getNumRegions(), false);
+  SmallVector<Region *, 4> worklist;
+  visited[begin->getRegionNumber()] = true;
+  worklist.push_back(begin);
+
+  // Depth-first traversal of the region successor graph.
+  while (!worklist.empty()) {
+    Region *current = worklist.pop_back_val();
+    SmallVector<RegionSuccessor, 4> successors;
+    branchOp.getSuccessorRegions(current->getRegionNumber(), successors);
+    for (RegionSuccessor &successor : successors) {
+      Region *next = successor.getSuccessor();
+      // A null successor denotes the parent op: control flow leaves `branchOp`
+      // without entering another of its regions.
+      if (!next)
+        continue;
+      assert(next->getParentOp() == op && "successor is not a region of op");
+      if (next == r)
+        return true;
+      unsigned nextIndex = next->getRegionNumber();
+      if (visited[nextIndex])
+        continue;
+      visited[nextIndex] = true;
+      worklist.push_back(next);
+    }
+  }
+  return false;
+}
+
+/// Return `true` if `a` and `b` are in mutually exclusive regions.
+///
+/// 1. Find the closest common ancestor of `a` and `b` that implements
+///    RegionBranchOpInterface.
+/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
+///    contained.
+/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
+///    mutually exclusive if they are not reachable from each other as per
+///    RegionBranchOpInterface::getSuccessorRegions.
+bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+  assert(a && "expected non-empty operation");
+  assert(b && "expected non-empty operation");
+
+  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
+  while (branchOp) {
+    // Check if b is inside branchOp. (We already know that a is.)
+    if (!branchOp->isProperAncestor(b)) {
+      // Check next enclosing RegionBranchOpInterface.
+      branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
+      continue;
+    }
+
+    // b is contained in branchOp. Retrieve the regions in which `a` and `b`
+    // are contained.
+    Region *regionA = nullptr, *regionB = nullptr;
+    for (Region &r : branchOp->getRegions()) {
+      if (r.findAncestorOpInRegion(*a)) {
+        assert(!regionA && "already found a region for a");
+        regionA = &r;
+      }
+      if (r.findAncestorOpInRegion(*b)) {
+        assert(!regionB && "already found a region for b");
+        regionB = &r;
+      }
+    }
+    assert(regionA && regionB && "could not find region of op");
+
+    // `a` and `b` are in mutually exclusive regions if neither region is
+    // reachable from the other region.
+    return !isRegionReachable(branchOp, regionA, regionB) &&
+           !isRegionReachable(branchOp, regionB, regionA);
+  }
+
+  // Could not find a common RegionBranchOpInterface among a's and b's
+  // ancestors.
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,10 +1,12 @@
 add_mlir_unittest(MLIRInterfacesTests
+  ControlFlowInterfacesTest.cpp
   DataLayoutInterfacesTest.cpp
   InferTypeOpInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
   PRIVATE
+  MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDLTI
   MLIRInferTypeOpInterface
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -0,0 +1,145 @@
+//===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+/// A dummy op that is also a terminator.
+struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
+  using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() { return "cftest.dummy_op"; }
+};
+
+/// All regions of this op are mutually exclusive.
+struct MutuallyExclusiveRegionsOp
+    : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "cftest.mutually_exclusive_regions_op";
+  }
+
+  // Neither the op nor any of its regions has a successor region.
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {}
+};
+
+/// Regions are executed sequentially.
+struct SequentialRegionsOp
+    : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
+
+  // Region 0 has Region 1 as a successor. `index` is None when the successors
+  // of the parent op itself are queried; no successors are reported for it.
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    if (index && *index == 0) {
+      Operation *thisOp = this->getOperation();
+      regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
+    }
+  }
+};
+
+/// A dialect putting all the above together.
+struct CFTestDialect : Dialect {
+  explicit CFTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
+    addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+  }
+  static StringRef getDialectNamespace() { return "cftest"; }
+};
+
+TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.mutually_exclusive_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
+}
+
+TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
+}
+
+TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.mutually_exclusive_regions_op"() (
+      {
+        "cftest.sequential_regions_op"() (
+              {"cftest.dummy_op"() : () -> ()},  // op1
+              {"cftest.dummy_op"() : () -> ()}   // op3
+            ) : () -> ()
+        "cftest.dummy_op"() : () -> ()
+      },
+      {"cftest.dummy_op"() : () -> ()}  // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 =
+      &testOp->getRegion(0).front().front().getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  Operation *op3 =
+      &testOp->getRegion(0).front().front().getRegion(1).front().front();
+
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
+}