Index: mlir/include/mlir/Transforms/CSE.h
===================================================================
--- mlir/include/mlir/Transforms/CSE.h
+++ mlir/include/mlir/Transforms/CSE.h
@@ -13,11 +13,13 @@
 #ifndef MLIR_TRANSFORMS_CSE_H_
 #define MLIR_TRANSFORMS_CSE_H_
 
+#include "mlir/IR/DialectInterface.h"
+
 namespace mlir {
 
 class DominanceInfo;
 class Operation;
 class RewriterBase;
 
 /// Eliminate common subexpressions within the given operation. This transform
 /// looks for and deduplicates equivalent operations.
@@ -27,6 +29,46 @@
                                    DominanceInfo &domInfo, Operation *op,
                                    bool *changed = nullptr);
 
+//===----------------------------------------------------------------------===//
+// DialectCSEInterface
+//===----------------------------------------------------------------------===//
+
+/// Dialect interface that lets a dialect customize how CSE treats its
+/// operations: how they are hashed, when two of them are considered
+/// equivalent, and what happens to the surviving operation when a duplicate
+/// is eliminated.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// These two hooks replace the default hashing and equivalence used by the
+  /// CSE hash table for operations of this dialect. They must be consistent:
+  /// whenever `isEqual(a, b)` returns true, `getHashValue(a)` must equal
+  /// `getHashValue(b)`.
+
+  /// Returns a hash for the operation.
+  virtual unsigned getHashValue(Operation *op) const = 0;
+
+  /// Returns true if two operations are considered to be equivalent.
+  virtual bool isEqual(Operation *lhs, Operation *rhs) const = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// This hook is called when 'op' is considered to be a common subexpression
+  /// of 'existingOp' and is going to be eliminated. It allows the dialect to
+  /// propagate information from 'op' to 'existingOp'. Because 'existingOp' may
+  /// still be matched against later operations, the mutation must change
+  /// neither its hash value nor which operations it is equivalent to. The
+  /// default implementation does nothing.
+  virtual void mergeOperations(Operation *existingOp, Operation *op) const {}
+};
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_CSE_H_
Index: mlir/lib/Transforms/CSE.cpp
===================================================================
--- mlir/lib/Transforms/CSE.cpp
+++ mlir/lib/Transforms/CSE.cpp
@@ -35,8 +35,14 @@
 namespace {
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    // Defer to the dialect's CSE interface, if it provides one.
+    if (auto *interface =
+            dyn_cast_or_null<DialectCSEInterface>(op->getDialect()))
+      return interface->getHashValue(op);
+
     return OperationEquivalence::computeHash(
-        const_cast<Operation *>(opC),
+        op,
         /*hashOperands=*/OperationEquivalence::directHashValue,
         /*hashResults=*/OperationEquivalence::ignoreHashValue,
         OperationEquivalence::IgnoreLocations);
@@ -49,6 +55,17 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    // Operations from different dialects are never equivalent. Checking this
+    // first also guarantees that both operations are compared by the same
+    // dialect interface below.
+    if (lhs->getDialect() != rhs->getDialect())
+      return false;
+
+    if (auto *interface =
+            dyn_cast_or_null<DialectCSEInterface>(lhs->getDialect()))
+      return interface->isEqual(lhs, rhs);
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
         OperationEquivalence::IgnoreLocations);
@@ -131,6 +148,14 @@
 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                      Operation *existing,
                                      bool hasSSADominance) {
+  // Let the dialect fold information from `op` into `existing`. The hook may
+  // mutate `existing` in place, so go through the rewriter to keep any
+  // attached listener informed of the modification.
+  if (auto *cseInterface =
+          dyn_cast_or_null<DialectCSEInterface>(op->getDialect()))
+    rewriter.updateRootInPlace(
+        existing, [&] { cseInterface->mergeOperations(existing, op); });
+
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
Index: mlir/test/Transforms/cse.mlir
===================================================================
--- mlir/test/Transforms/cse.mlir
+++ mlir/test/Transforms/cse.mlir
@@ -459,3 +459,28 @@
 // CHECK: }
 // CHECK-NOT: scf.if
 // CHECK: return %[[if]], %[[if]]
+
+// The test dialect's CSE interface ignores 'test.non_essential' and keeps the
+// smallest value on the surviving operation.
+func.func @cse_test_dialect_interface() -> (i32, i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.non_essential = "b"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+  %2 = "test.always_speculatable_op"() {test.non_essential = "c"} : () -> i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]], %[[result]]
+
+// Other discardable attributes still distinguish operations.
+func.func @cse_test_dialect_interface_keeps_essential_attrs() -> (i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.essential = "x", test.non_essential = "a"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.essential = "y", test.non_essential = "a"} : () -> i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface_keeps_essential_attrs
+// CHECK-NEXT: %[[x:.+]] = "test.always_speculatable_op"() {test.essential = "x", test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: %[[y:.+]] = "test.always_speculatable_op"() {test.essential = "y", test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[x]], %[[y]]
Index: mlir/test/lib/Dialect/Test/TestDialect.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
@@ -458,6 +459,85 @@
   }
 };
 
+/// Test implementation of DialectCSEInterface: operations are deduplicated
+/// regardless of their 'test.non_essential' attribute, and the surviving
+/// operation keeps the lexicographically smallest such value.
+struct TestCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  /// Discardable attribute that is ignored when hashing and comparing ops.
+  static constexpr llvm::StringLiteral nonEssentialAttrName =
+      "test.non_essential";
+
+  /// Returns the discardable attributes of `op` without 'test.non_essential'.
+  /// The result is a uniqued DictionaryAttr, so hashing it and comparing it
+  /// by identity agree with each other.
+  static DictionaryAttr getEssentialDiscardableAttrs(Operation *op) {
+    NamedAttrList attributes(op->getDiscardableAttrDictionary());
+    attributes.erase(nonEssentialAttrName);
+    return attributes.getDictionary(op->getContext());
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Return a hash that ignores 'test.non_essential' attributes.
+  unsigned getHashValue(Operation *op) const override {
+    auto hashOp = [](Operation *opToHash) {
+      return llvm::hash_combine(opToHash->getName(),
+                                opToHash->getResultTypes(),
+                                opToHash->hashProperties(),
+                                getEssentialDiscardableAttrs(opToHash));
+    };
+
+    return OperationEquivalence::computeHash(
+        op,
+        /*hashOp=*/hashOp,
+        /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  /// Return true if the operations are the same except for their
+  /// 'test.non_essential' attributes.
+  bool isEqual(Operation *lhs, Operation *rhs) const override {
+    auto checkOp = [](Operation *lhsOp, Operation *rhsOp) -> LogicalResult {
+      // Properties are compared by value rather than by their hash so that a
+      // hash collision cannot make two different operations look equivalent.
+      return success(lhsOp->getName() == rhsOp->getName() &&
+                     lhsOp->getNumRegions() == rhsOp->getNumRegions() &&
+                     lhsOp->getNumSuccessors() == rhsOp->getNumSuccessors() &&
+                     lhsOp->getNumOperands() == rhsOp->getNumOperands() &&
+                     lhsOp->getNumResults() == rhsOp->getNumResults() &&
+                     lhsOp->getPropertiesAsAttribute() ==
+                         rhsOp->getPropertiesAsAttribute() &&
+                     getEssentialDiscardableAttrs(lhsOp) ==
+                         getEssentialDiscardableAttrs(rhsOp));
+    };
+
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, checkOp, OperationEquivalence::IgnoreLocations);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Propagate 'test.non_essential' from the operation being eliminated to the
+  /// surviving one. If both carry the attribute, keep the smaller string.
+  void mergeOperations(Operation *existingOp,
+                       Operation *opToBeDeleted) const override {
+    auto incoming =
+        opToBeDeleted->getAttrOfType<StringAttr>(nonEssentialAttrName);
+    if (!incoming)
+      return;
+    auto current = existingOp->getAttrOfType<StringAttr>(nonEssentialAttrName);
+    if (!current || incoming.getValue() < current.getValue())
+      existingOp->setAttr(nonEssentialAttrName, incoming);
+  }
+};
+
 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
 public:
   TestReductionPatternInterface(Dialect *dialect)
@@ -565,7 +645,8 @@
   addInterface<TestOpAsmInterface>(blobInterface);
 
   addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+                TestReductionPatternInterface, TestBytecodeDialectInterface,
+                TestCSEInterface>();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific