diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -859,6 +859,8 @@
           [](Value v) { return hash_value(v); },
       function_ref<llvm::hash_code(Value)> hashResults =
           [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Region &)> hashRegions =
+          [](Region & /*r*/) { return llvm::hash_code{}; },
       Flags flags = Flags::None);
 
   /// Helper that can be used with `computeHash` above to ignore operation
@@ -867,6 +869,11 @@
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
+  /// Helper that can be used with `computeHash` above to ignore the contents
+  /// of operation regions.
+  static llvm::hash_code ignoreRegionHashValue(Region &) {
+    return llvm::hash_code{};
+  }
 
   /// Compare two operations and return if they are equivalent.
   /// `mapOperands` and `mapResults` are optional callbacks that allows the
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -621,7 +621,8 @@
 
 llvm::hash_code OperationEquivalence::computeHash(
     Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
-    function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
+    function_ref<llvm::hash_code(Value)> hashResults,
+    function_ref<llvm::hash_code(Region &)> hashRegions, Flags flags) {
   // Hash operations based upon their:
   //   - Operation Name
   //   - Attributes
@@ -645,6 +646,10 @@
   //   - Operands
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
+
+  //   - Regions
+  for (Region &r : op->getRegions())
+    hash = llvm::hash_combine(hash, hashRegions(r));
   return hash;
 }
 
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/DenseMapInfo.h"
@@ -35,8 +36,9 @@
   static unsigned getHashValue(const Operation *opC) {
     return OperationEquivalence::computeHash(
         const_cast<Operation *>(opC),
         /*hashOperands=*/OperationEquivalence::directHashValue,
         /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        /*hashRegions=*/OperationEquivalence::ignoreRegionHashValue,
         OperationEquivalence::IgnoreLocations);
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
@@ -47,11 +49,57 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    // Maps values defined in blocks owned directly by `lhs` (block arguments
+    // and results of nested operations) to their counterparts under `rhs`.
+    llvm::DenseMap<Value, Value> areEquivalentValues;
+
+    // Seed the mapping with the entry block arguments when both operations
+    // have a single single-block region with the same number of arguments.
+    if (lhs->getNumRegions() == 1 && rhs->getNumRegions() == 1 &&
+        llvm::hasSingleElement(lhs->getRegion(0)) &&
+        llvm::hasSingleElement(rhs->getRegion(0)) &&
+        lhs->getRegion(0).getNumArguments() ==
+            rhs->getRegion(0).getNumArguments()) {
+      for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
+                                   rhs->getRegion(0).getArguments())) {
+        areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
+      }
+    }
+
+    // Returns the operation owning the block in which `v` is defined.
+    auto getParent = [](Value v) -> Operation * {
+      if (auto blockArg = v.dyn_cast<BlockArgument>())
+        return blockArg.getParentBlock()->getParentOp();
+      return v.getDefiningOp()->getParentOp();
+    };
+
+    // Operands are equivalent if they are the same value, or if both are
+    // defined in blocks owned by `lhs`/`rhs` and were mapped to each other.
+    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+      if (lhsValue == rhsValue)
+        return success();
+      if (getParent(lhsValue) == lhs && getParent(rhsValue) == rhs &&
+          areEquivalentValues.lookup(lhsValue) == rhsValue)
+        return success();
+      return failure();
+    };
+
+    // Records the correspondence between results of operations nested directly
+    // in `lhs`/`rhs`, failing on a conflicting mapping. Other results,
+    // including those of `lhs` and `rhs` themselves, are not recorded.
+    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
+      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
+        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
+        return success(insertion.first->second == rhsResult);
+      }
+      return success();
+    };
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-        /*mapOperands=*/OperationEquivalence::exactValueMatch,
-        /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
-        OperationEquivalence::IgnoreLocations);
+        /*mapOperands=*/mapOperands, /*mapResults=*/mapResults,
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace
@@ -204,7 +252,8 @@
-  // Don't simplify operations with nested blocks. We don't currently model
-  // equality comparisons correctly among other things. It is also unclear
-  // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0)
+  // Only simplify operations with no regions or a single single-block region.
+  // Equivalence of more general region structures is not modeled, and it is
+  // unclear whether we would want to CSE such operations.
+  if (!(op->getNumRegions() == 0 ||
+        (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)))))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -429,6 +429,7 @@
       auto opHash = OperationEquivalence::computeHash(
           &op, OperationEquivalence::ignoreHashValue,
           OperationEquivalence::ignoreHashValue,
+          OperationEquivalence::ignoreRegionHashValue,
           OperationEquivalence::IgnoreLocations);
       hash = llvm::hash_combine(hash, opHash);
     }