diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -204,6 +204,26 @@
   /// take O(N) where N is the number of operations within the parent block.
   bool isBeforeInBlock(Operation *other);
 
+  /// Returns true if this operation is structurally identical to `other`:
+  /// both must have the same name, attribute dictionary, and number of
+  /// operands, results, successors and regions; operand and result types must
+  /// match; and nested regions must compare equal (see Region::equals).
+  /// Locations are not compared.
+  ///
+  /// Operands and results are matched through `valuesMap`: a value of this
+  /// operation that is not mapped yet is mapped to the corresponding value of
+  /// `other`, and an already mapped value must map to that same value. The
+  /// map is updated in place (also when returning false), so it can be shared
+  /// across several comparisons.
+  ///
+  /// Only the number of successors of this operation is compared; successors
+  /// of nested operations are checked by Region::equals.
+  ///
+  /// NOTE: the mapping is not checked to be one-to-one: two distinct values
+  /// defined outside of the compared IR may both be mapped to the same value
+  /// of `other`, in which case the comparison is not symmetric.
+  bool equals(Operation *other, DenseMap<Value, Value> &valuesMap);
+
   void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
   void print(raw_ostream &os, AsmState &state,
              const OpPrintingFlags &flags = llvm::None);
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -72,6 +72,14 @@
     return &Region::blocks;
   }
 
+  /// Returns true if this region is structurally identical to `other`: both
+  /// must have the same number of blocks, and blocks at the same position
+  /// must have the same argument types and the same number of operations,
+  /// pairwise equal according to Operation::equals. Successors of nested
+  /// operations must refer to blocks at the same position. Block arguments
+  /// and values defined in the regions are recorded in `valuesMap`.
+  bool equals(Region &other, DenseMap<Value, Value> &valuesMap);
+
   //===--------------------------------------------------------------------===//
   // Argument Handling
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -340,6 +340,47 @@
   return orderIndex < other->orderIndex;
 }
 
+/// Maps `value` to `otherValue` in `valuesMap` unless `value` is already
+/// mapped. Returns false if the two values have different types, or if
+/// `value` was previously mapped to a value other than `otherValue`.
+static bool checkAndMapValue(Value value, Value otherValue,
+                             DenseMap<Value, Value> &valuesMap) {
+  if (value.getType() != otherValue.getType())
+    return false;
+  auto insertion = valuesMap.try_emplace(value, otherValue);
+  return insertion.first->getSecond() == otherValue;
+}
+
+bool Operation::equals(Operation *other, DenseMap<Value, Value> &valuesMap) {
+  if (other->getName() != getName() ||
+      other->getAttrDictionary() != getAttrDictionary() ||
+      other->getNumRegions() != getNumRegions() ||
+      other->getNumSuccessors() != getNumSuccessors() ||
+      other->getNumOperands() != getNumOperands() ||
+      other->getNumResults() != getNumResults())
+    return false;
+
+  // Operands and results must correspond pairwise, consistently with the
+  // mapping established so far (by enclosing regions or earlier operations).
+  for (auto operandPair : llvm::zip(getOperands(), other->getOperands())) {
+    if (!checkAndMapValue(std::get<0>(operandPair), std::get<1>(operandPair),
+                          valuesMap))
+      return false;
+  }
+  for (auto resultPair : llvm::zip(getResults(), other->getResults())) {
+    if (!checkAndMapValue(std::get<0>(resultPair), std::get<1>(resultPair),
+                          valuesMap))
+      return false;
+  }
+
+  // Nested regions must compare equal as well.
+  for (auto regionPair : llvm::zip(getRegions(), other->getRegions())) {
+    if (!std::get<0>(regionPair).equals(std::get<1>(regionPair), valuesMap))
+      return false;
+  }
+  return true;
+}
+
 /// Update the order index of this operation of this operation if necessary,
 /// potentially recomputing the order of the parent block.
 void Operation::updateOrderIfNecessary() {
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -158,6 +158,61 @@
   return reinterpret_cast<Region *>(reinterpret_cast<char *>(Anchor) - Offset);
 }
 
+bool Region::equals(Region &other, DenseMap<Value, Value> &valuesMap) {
+  // Successors of nested operations must refer to corresponding blocks, i.e.
+  // to blocks at the same position in the two regions.
+  DenseMap<Block *, Block *> blocksMap;
+  auto mapBlock = [&](Block *block, Block *otherBlock) {
+    auto insertion = blocksMap.try_emplace(block, otherBlock);
+    return insertion.first->getSecond() == otherBlock;
+  };
+
+  // Walk both block lists in lockstep: their sizes are not known upfront and
+  // we'd like to avoid traversing them twice.
+  iterator blockIt = begin(), otherBlockIt = other.begin();
+  for (; blockIt != end() && otherBlockIt != other.end();
+       ++blockIt, ++otherBlockIt) {
+    Block &block = *blockIt;
+    Block &otherBlock = *otherBlockIt;
+    if (block.getNumArguments() != otherBlock.getNumArguments() ||
+        !mapBlock(&block, &otherBlock))
+      return false;
+
+    // Block arguments must have the same types and map consistently.
+    for (auto argPair :
+         llvm::zip(block.getArguments(), otherBlock.getArguments())) {
+      Value arg = std::get<0>(argPair);
+      Value otherArg = std::get<1>(argPair);
+      if (arg.getType() != otherArg.getType())
+        return false;
+      auto insertion = valuesMap.try_emplace(arg, otherArg);
+      if (insertion.first->getSecond() != otherArg)
+        return false;
+    }
+
+    // Walk both operation lists in lockstep as well.
+    Block::iterator opIt = block.begin(), otherOpIt = otherBlock.begin();
+    for (; opIt != block.end() && otherOpIt != otherBlock.end();
+         ++opIt, ++otherOpIt) {
+      // Compare the operations themselves (recursively for their regions).
+      if (!opIt->equals(&*otherOpIt, valuesMap))
+        return false;
+      // Operation::equals only compares the number of successors: check here
+      // that the successors refer to corresponding blocks.
+      for (auto successorPair :
+           llvm::zip(opIt->getSuccessors(), otherOpIt->getSuccessors())) {
+        if (!mapBlock(std::get<0>(successorPair), std::get<1>(successorPair)))
+          return false;
+      }
+    }
+    // Both blocks must have the same number of operations.
+    if (opIt != block.end() || otherOpIt != otherBlock.end())
+      return false;
+  }
+  // Both regions must have the same number of blocks.
+  return blockIt == end() && otherBlockIt == other.end();
+}
+
 /// This is a trait method invoked when a basic block is added to a region.
 /// We keep the region pointer up to date.
 void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) {
diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/operation-equality.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt %s -split-input-file --test-operations-equality | FileCheck %s
+
+
+// CHECK-LABEL: test.top_level_op
+// CHECK-SAME: compares equals
+
+"test.top_level_op"() : () -> ()
+"test.top_level_op"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_name_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_name_mismatch"() : () -> ()
+"test.top_level_name_mismatch2"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_attr_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_attr_mismatch"() { foo = "bar" } : () -> ()
+"test.top_level_op_attr_mismatch"() { foo = "bar2"} : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_cfg
+// CHECK-SAME: compares equals
+
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+  { attr = "foo" } : () -> ()
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+  { attr = "foo" } : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1) : (f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg1) : (f32, f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_arg_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.dataflow_match
+// CHECK-SAME: compares equals
+
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+// -----
+
+// CHECK-LABEL: test.dataflow_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#1, %0#0) : (i32, i32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.successor_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.successor_mismatch"() ({
+  ^bb0:
+    "test.some_branching_op"() [^bb1, ^bb2] : () -> ()
+  ^bb1:
+    "test.some_branching_op"() : () -> ()
+  ^bb2:
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.successor_mismatch"() ({
+  ^bb0:
+    "test.some_branching_op"() [^bb2, ^bb1] : () -> ()
+  ^bb1:
+    "test.some_branching_op"() : () -> ()
+  ^bb2:
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@
   TestInterfaces.cpp
   TestMatchers.cpp
   TestOpaqueLoc.cpp
+  TestOperationEquals.cpp
   TestPrintDefUse.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -0,0 +1,47 @@
+//===- TestOperationEquals.cpp - Passes to test Operation::equals ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This pass compares the two top-level operations of the module with
+/// Operation::equals and prints whether they compare equal.
+struct TestOperationEqualPass
+    : public PassWrapper<TestOperationEqualPass, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-operations-equality"; }
+  StringRef getDescription() const final { return "Test operations equality."; }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    // Expects two operations at the top-level:
+    int opCount = module.getBody()->getOperations().size();
+    if (opCount != 2) {
+      module.emitError() << "expected 2 top-level ops in the module, got "
+                         << opCount;
+      signalPassFailure();
+      return;
+    }
+    DenseMap<Value, Value> valuesMap;
+    Operation *first = &module.getBody()->front();
+    llvm::outs() << first->getName().getStringRef() << " with attr "
+                 << first->getAttrDictionary();
+    if (first->equals(&module.getBody()->back(), valuesMap))
+      llvm::outs() << " compares equals.\n";
+    else
+      llvm::outs() << " compares NOT equals!\n";
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestOperationEqualPass() {
+  PassRegistration<TestOperationEqualPass>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
 void registerTestMatchers();
+void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
@@ -122,6 +123,7 @@
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();
   registerTestMatchers();
+  registerTestOperationEqualPass();
   registerTestPrintDefUsePass();
   registerTestPrintNestingPass();
   registerTestReducer();