diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1697,6 +1697,26 @@
          std::equal(adl_begin(Range) + 1, adl_end(Range), adl_begin(Range)));
 }
 
+/// Compare two ranges element-wise using the provided predicate. Returns
+/// false as soon as a pair fails the predicate, None if all compared pairs
+/// satisfy it but the ranges differ in size, and true otherwise.
+template <typename R, typename U, typename Predicate>
+Optional<bool> all_of_zip(R &&Lhs, U &&Rhs, Predicate P) {
+  auto Literator = adl_begin(Lhs);
+  auto Riterator = adl_begin(Rhs);
+  auto Lend = adl_end(Lhs);
+  auto Rend = adl_end(Rhs);
+  while (Literator != Lend && Riterator != Rend) {
+    if (!P(*Literator, *Riterator))
+      return false;
+    ++Literator;
+    ++Riterator;
+  }
+  if (Literator != Lend || Riterator != Rend)
+    return None;
+  return true;
+}
+
 /// Provide a container algorithm similar to C++ Library Fundamentals v2's
 /// `erase_if` which is equivalent to:
 ///
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -876,4 +876,15 @@
   EXPECT_EQ(2, Destructors);
 }
 
+TEST(STLExtrasTest, AllOfZip) {
+  std::vector<int> v1 = {0, 4, 2, 1};
+  std::vector<int> v2 = {1, 4, 3, 6};
+  EXPECT_EQ(true, all_of_zip(v1, v2, [](int L, int R) { return L <= R; }));
+  EXPECT_EQ(false, all_of_zip(v1, v2, [](int L, int R) { return L < R; }));
+  std::vector<int> v3 = {1, 4};
+  EXPECT_EQ(None, all_of_zip(v1, v3, [](int L, int R) { return true; }));
+  // A failing pair is reported before the size mismatch is detected.
+  EXPECT_EQ(false, all_of_zip(v1, v3, [](int L, int R) { return L > R; }));
+}
+
 } // namespace
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -901,21 +901,43 @@
   enum Flags {
     None = 0,
 
-    /// This flag signals that operands should not be considered when checking
-    /// for equivalence. This allows for users to implement there own
-    /// equivalence schemes for operand values. The number of operands are still
-    /// checked, just not the operands themselves.
-    IgnoreOperands = 1,
+    /// When set, locations of operations and block arguments are ignored.
+    IgnoreLocations = 1,
 
-    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands)
+    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations)
   };
 
   /// Compute a hash for the given operation.
-  static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None);
+  /// `hashOperands` and `hashResults` are invoked on every operand/result and
+  /// their results are combined into the hash; the defaults hash the Value.
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None);
+
+  /// Helper that can be passed to `computeHash` above to ignore operand and
+  /// result values when hashing.
+  static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
 
   /// Compare two operations and return if they are equivalent.
-  static bool isEquivalentTo(Operation *lhs, Operation *rhs,
-                             Flags flags = Flags::None);
+  /// `mapOperands` and `mapResults` are callbacks that allow the caller to
+  /// check the mapping of SSA values between the lhs and rhs operations. They
+  /// are expected to return success if the mapping is valid and failure if it
+  /// conflicts with a previous mapping.
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Value, Value)> mapOperands,
+                 function_ref<LogicalResult(Value, Value)> mapResults,
+                 Flags flags = Flags::None);
+
+  /// Helper that can be passed to `isEquivalentTo` above to accept any
+  /// operand/result mapping.
+  static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
+    return success();
+  }
 };
 
 /// Enable Bitmask enums for OperationEquivalence::Flags.
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -108,7 +108,9 @@
 
     hash = llvm::hash_combine(
         hash, OperationEquivalence::computeHash(
-                  &op, OperationEquivalence::Flags::IgnoreOperands));
+                  &op, OperationEquivalence::ignoreHashValue,
+                  OperationEquivalence::ignoreHashValue,
+                  OperationEquivalence::Flags::IgnoreLocations));
   }
 
   return hash;
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,7 +522,9 @@
 // Operation Equivalency
 //===----------------------------------------------------------------------===//
 
-llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
+llvm::hash_code OperationEquivalence::computeHash(
+    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
   // Hash operations based upon their:
   //   - Operation Name
   //   - Attributes
@@ -531,37 +533,108 @@
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   //   - Operands
-  bool ignoreOperands = flags & Flags::IgnoreOperands;
-  if (!ignoreOperands) {
-    // TODO: Allow commutative operations to have different ordering.
-    hash = llvm::hash_combine(
-        hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
-  }
+  for (Value operand : op->getOperands())
+    hash = llvm::hash_combine(hash, hashOperands(operand));
+  //   - Results
+  for (Value result : op->getResults())
+    hash = llvm::hash_combine(hash, hashResults(result));
   return hash;
 }
 
-bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
-                                          Flags flags) {
+static bool
+isRegionEquivalentTo(Region *lhs, Region *rhs,
+                     function_ref<LogicalResult(Value, Value)> mapOperands,
+                     function_ref<LogicalResult(Value, Value)> mapResults,
+                     OperationEquivalence::Flags flags) {
+  DenseMap<Block *, Block *> blocksMap;
+  Optional<bool> allBlockMatch =
+      llvm::all_of_zip(*lhs, *rhs, [&](Block &lBlock, Block &rBlock) {
+        // Check block arguments.
+        if (lBlock.getNumArguments() != rBlock.getNumArguments())
+          return false;
+
+        // Map the two blocks.
+        auto insertion = blocksMap.insert({&lBlock, &rBlock});
+        if (insertion.first->getSecond() != &rBlock)
+          return false;
+
+        for (auto argPair :
+             llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
+          Value curArg = std::get<0>(argPair);
+          Value otherArg = std::get<1>(argPair);
+          if (curArg.getType() != otherArg.getType())
+            return false;
+          if (!(flags & OperationEquivalence::IgnoreLocations) &&
+              curArg.getLoc() != otherArg.getLoc())
+            return false;
+          // Check if this value was already mapped to another value.
+          if (failed(mapOperands(curArg, otherArg)))
+            return false;
+        }
+
+        Optional<bool> allOpsMatch = llvm::all_of_zip(
+            lBlock, rBlock, [&](Operation &lOp, Operation &rOp) {
+              // Check for op equality (recursively).
+              if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
+                                                        mapResults, flags))
+                return false;
+              // Check successor mapping.
+              for (auto successorsPair :
+                   llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
+                Block *curSuccessor = std::get<0>(successorsPair);
+                Block *otherSuccessor = std::get<1>(successorsPair);
+                auto insertion =
+                    blocksMap.insert({curSuccessor, otherSuccessor});
+                if (insertion.first->getSecond() != otherSuccessor)
+                  return false;
+              }
+              return true;
+            });
+        return allOpsMatch.hasValue() && allOpsMatch.getValue();
+      });
+  return allBlockMatch.hasValue() && allBlockMatch.getValue();
+}
+
+bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Value, Value)> mapOperands,
+    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // Compare the operation name.
-  if (lhs->getName() != rhs->getName())
-    return false;
-  // Check operand counts.
-  if (lhs->getNumOperands() != rhs->getNumOperands())
-    return false;
-  // Compare attributes.
-  if (lhs->getAttrDictionary() != rhs->getAttrDictionary())
+  // Compare the operation properties.
+  if (lhs->getName() != rhs->getName() ||
+      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
+      lhs->getNumRegions() != rhs->getNumRegions() ||
+      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
+      lhs->getNumOperands() != rhs->getNumOperands() ||
+      lhs->getNumResults() != rhs->getNumResults())
     return false;
-  // Compare result types.
-  if (lhs->getResultTypes() != rhs->getResultTypes())
+  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
-  // Compare operands.
-  bool ignoreOperands = flags & Flags::IgnoreOperands;
-  if (ignoreOperands)
-    return true;
-  // TODO: Allow commutative operations to have different ordering.
-  return std::equal(lhs->operand_begin(), lhs->operand_end(),
-                    rhs->operand_begin());
+
+  // Check mapping of operands.
+  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+    Value curArg = std::get<0>(operandPair);
+    Value otherArg = std::get<1>(operandPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (failed(mapOperands(curArg, otherArg)))
+      return false;
+  }
+  // Check mapping of results.
+  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
+    Value curRes = std::get<0>(resultPair);
+    Value otherRes = std::get<1>(resultPair);
+    if (curRes.getType() != otherRes.getType())
+      return false;
+    if (failed(mapResults(curRes, otherRes)))
+      return false;
+  }
+  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
+    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
+                              &std::get<1>(regionPair), mapOperands, mapResults,
+                              flags))
+      return false;
+  return true;
 }
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -28,7 +28,10 @@
 namespace {
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
-    return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
+    return OperationEquivalence::computeHash(
+        const_cast<Operation *>(opC), [](Value v) { return hash_value(v); },
+        OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
@@ -38,8 +41,12 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-    return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
-                                                const_cast<Operation *>(rhsC));
+    // Operands must be the exact same SSA values for two ops to be CSE'd;
+    // results are fresh definitions, so any result mapping is accepted.
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, [](Value a, Value b) { return success(a == b); },
+        OperationEquivalence::ignoreValueEquivalence,
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // end anonymous namespace
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -428,7 +428,9 @@
       orderIt += numResults;
     }
     auto opHash = OperationEquivalence::computeHash(
-        &op, OperationEquivalence::Flags::IgnoreOperands);
+        &op, OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
     hash = llvm::hash_combine(hash, opHash);
   }
 }
@@ -491,7 +493,9 @@
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands))
+            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
+            OperationEquivalence::Flags::IgnoreLocations))
       return failure();
 
     // Compare the operands of the two operations. If the operand is within
diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/operation-equality.mlir
@@ -0,0 +1,186 @@
+// RUN: mlir-opt %s -split-input-file --test-operations-equality | FileCheck %s
+
+
+// CHECK-LABEL: test.top_level_op
+// CHECK-SAME: compares equals
+
+"test.top_level_op"() : () -> ()
+"test.top_level_op"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_strict_loc
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_strict_loc"() { strict_loc_check } : () -> ()
+"test.top_level_op_strict_loc"() { strict_loc_check } : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_loc_match
+// CHECK-SAME: compares equals
+
+"test.top_level_op_loc_match"() { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_loc_match"() { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_block_loc_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_block_loc_mismatch"() ({
+  ^bb0(%a : i32):
+}) { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_block_loc_mismatch"() ({
+  ^bb0(%a : i32):
+}) { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_block_loc_match
+// CHECK-SAME: compares equals
+
+"test.top_level_op_block_loc_match"() ({
+  ^bb0(%a : i32 loc("bar")):
+}) { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_block_loc_match"() ({
+  ^bb0(%a : i32 loc("bar")):
+}) { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_name_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_name_mismatch"() : () -> ()
+"test.top_level_name_mismatch2"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_attr_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_attr_mismatch"() { foo = "bar" } : () -> ()
+"test.top_level_op_attr_mismatch"() { foo = "bar2"} : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_cfg
+// CHECK-SAME: compares equals
+
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+  { attr = "foo" } : () -> ()
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+  { attr = "foo" } : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1) : (f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg1) : (f32, f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_arg_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.dataflow_match
+// CHECK-SAME: compares equals
+
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.dataflow_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#1, %0#0) : (i32, i32) -> ()
+  }) : () -> ()
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -264,4 +264,4 @@
     "foo.yield"(%0) : (i32) -> ()
   }
   return
-}
+}
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@
   TestInterfaces.cpp
   TestMatchers.cpp
   TestOpaqueLoc.cpp
+  TestOperationEquals.cpp
   TestPrintDefUse.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -0,0 +1,58 @@
+//===- TestOperationEquals.cpp - Passes to test OperationEquivalence ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This pass expects a module with exactly two top-level operations and prints
+/// whether OperationEquivalence considers them equivalent.
+struct TestOperationEqualPass
+    : public PassWrapper<TestOperationEqualPass, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-operations-equality"; }
+  StringRef getDescription() const final { return "Test operations equality."; }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    // Expects two operations at the top-level:
+    int opCount = module.getBody()->getOperations().size();
+    if (opCount != 2) {
+      module.emitError() << "expected 2 top-level ops in the module, got "
+                         << opCount;
+      signalPassFailure();
+      return;
+    }
+    // Map lhs values to rhs values; a value encountered again must map to the
+    // same counterpart, otherwise the dataflow of the two ops differs.
+    DenseMap<Value, Value> valuesMap;
+    auto mapValue = [&](Value lhs, Value rhs) {
+      auto insertion = valuesMap.insert({lhs, rhs});
+      return success(insertion.first->second == rhs);
+    };
+
+    Operation *first = &module.getBody()->front();
+    llvm::outs() << first->getName().getStringRef() << " with attr "
+                 << first->getAttrDictionary();
+    OperationEquivalence::Flags flags{};
+    if (!first->hasAttr("strict_loc_check"))
+      flags |= OperationEquivalence::IgnoreLocations;
+    if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
+                                             mapValue, mapValue, flags))
+      llvm::outs() << " compares equals.\n";
+    else
+      llvm::outs() << " compares NOT equals!\n";
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestOperationEqualPass() {
+  PassRegistration<TestOperationEqualPass>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
 void registerTestMatchers();
+void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
@@ -122,6 +123,7 @@
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();
   registerTestMatchers();
+  registerTestOperationEqualPass();
   registerTestPrintDefUsePass();
   registerTestPrintNestingPass();
   registerTestReducer();