Index: mlir/include/mlir/IR/OperationSupport.h
===================================================================
--- mlir/include/mlir/IR/OperationSupport.h
+++ mlir/include/mlir/IR/OperationSupport.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BlockSupport.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/TypeRange.h"
@@ -1261,6 +1262,38 @@
 /// Enable Bitmask enums for OperationEquivalence::Flags.
 LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
 
+//===----------------------------------------------------------------------===//
+// OperationEquivalenceInterface
+//===----------------------------------------------------------------------===//
+
+/// This is the interface that allows dialects to control operation
+/// equivalence: discardable attributes that the dialect reports as
+/// non-essential are ignored by OperationEquivalence when hashing and
+/// comparing operations of that dialect.
+class DialectOperationEquivalenceInterface
+    : public DialectInterface::Base<DialectOperationEquivalenceInterface> {
+public:
+  DialectOperationEquivalenceInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Returns true if the given 'attr' contains attributes that can be
+  /// ignored in the equivalence comparison. This acts as a pre-check: it must
+  /// return true whenever `isNonEssentialAttribute` holds for some entry of
+  /// 'attr', otherwise that entry is not ignored.
+  virtual bool containsNonEssentialAttribute(DictionaryAttr attr) const {
+    return false;
+  }
+
+  /// Returns true if the given named attribute 'namedAttr' can be ignored
+  /// in the equivalence comparison.
+  virtual bool isNonEssentialAttribute(NamedAttribute namedAttr) const {
+    return false;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // OperationFingerPrint
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Transforms/CSEUtils.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Transforms/CSEUtils.h
@@ -0,0 +1,53 @@
+//===- CSEUtils.h - CSE utilities -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines interfaces for CSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_CSEUTILS_H
+#define MLIR_TRANSFORMS_CSEUTILS_H
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CSEInterface
+//===----------------------------------------------------------------------===//
+
+/// This is the interface that allows dialects to customize CSE: it decides
+/// which discardable attributes the operation kept by CSE ends up with when
+/// an equivalent operation is eliminated in its favor.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Returns the discardable attribute dictionary to attach to the operation
+  /// kept by CSE, derived from `opAttr` (the discardable attributes of the
+  /// operation being eliminated) and `existingOpAttr` (those of the kept
+  /// operation). The default implementation keeps `existingOpAttr`.
+  ///
+  /// The kept operation may already be recorded in CSE's table of known
+  /// operations, so implementations should only change attributes that
+  /// DialectOperationEquivalenceInterface reports as non-essential.
+  virtual DictionaryAttr
+  combineDiscardableAttributes(DictionaryAttr opAttr,
+                               DictionaryAttr existingOpAttr) const {
+    return existingOpAttr;
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_CSEUTILS_H
Index: mlir/lib/IR/OperationSupport.cpp
===================================================================
--- mlir/lib/IR/OperationSupport.cpp
+++ mlir/lib/IR/OperationSupport.cpp
@@ -15,7 +15,8 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/SHA1.h"
 #include <numeric>
 #include <optional>
@@ -650,11 +651,24 @@
     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
   // Hash operations based upon their:
   //   - Operation Name
-  //   - Attributes
   //   - Result Types
-  llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
-                         op->getResultTypes(), op->hashProperties());
+  llvm::hash_code hash = llvm::hash_combine(op->getName(), op->getResultTypes(),
+                                            op->hashProperties());
+  //   - Attributes
+  DictionaryAttr attr = op->getDiscardableAttrDictionary();
+  // If the dialect implements `DialectOperationEquivalenceInterface`, skip
+  // non-essential attributes so that the hash agrees with `isEquivalentTo`.
+  if (auto *interface = dyn_cast_or_null<DialectOperationEquivalenceInterface>(
+          op->getDialect()))
+    if (interface->containsNonEssentialAttribute(attr)) {
+      SmallVector<NamedAttribute> attributes(
+          llvm::make_filter_range(attr, [interface](NamedAttribute namedAttr) {
+            return !interface->isNonEssentialAttribute(namedAttr);
+          }));
+      attr = DictionaryAttr::get(op->getContext(), attributes);
+    }
+
+  hash = llvm::hash_combine(hash, attr);
 
   //   - Location if required
   if (!(flags & Flags::IgnoreLocations))
@@ -770,16 +784,41 @@
   if (lhs == rhs)
     return true;
 
-  // 1. Compare the operation properties.
+  // 1. Compare the operation properties except for discardable attributes.
   if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
       lhs->getNumOperands() != rhs->getNumOperands() ||
       lhs->getNumResults() != rhs->getNumResults() ||
       lhs->hashProperties() != rhs->hashProperties())
     return false;
+
+  // Compare discardable attributes.
+  DictionaryAttr lDict = lhs->getDiscardableAttrDictionary();
+  DictionaryAttr rDict = rhs->getDiscardableAttrDictionary();
+  if (lDict != rDict) {
+    // Look up a dialect interface if attributes are different. The dialect
+    // may be null (e.g. not loaded), so use `dyn_cast_or_null`.
+    auto *interface = dyn_cast_or_null<DialectOperationEquivalenceInterface>(
+        lhs->getDialect());
+    // If there is no dialect interface, give up.
+    if (!interface)
+      return false;
+    if (!interface->containsNonEssentialAttribute(lDict) &&
+        !interface->containsNonEssentialAttribute(rDict))
+      return false;
+
+    // A helper to skip non-essential attributes.
+    auto skipNonEssentialAttr = [interface](NamedAttribute namedAttr) {
+      return !interface->isNonEssentialAttribute(namedAttr);
+    };
+
+    // Check if they are equivalent when non-essential attributes are stripped.
+    if (!llvm::equal(llvm::make_filter_range(lDict, skipNonEssentialAttr),
+                     llvm::make_filter_range(rDict, skipNonEssentialAttr)))
+      return false;
+  }
+
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
Index: mlir/lib/Transforms/CSE.cpp
===================================================================
--- mlir/lib/Transforms/CSE.cpp
+++ mlir/lib/Transforms/CSE.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/CSEUtils.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/ScopedHashTable.h"
@@ -113,6 +114,15 @@
 void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                Operation *existing,
                                bool hasSSADominance) {
+  // Let the dialect merge the discardable attributes of `op` into `existing`.
+  if (auto *interface =
+          dyn_cast_or_null<DialectCSEInterface>(op->getDialect())) {
+    DictionaryAttr attr = interface->combineDiscardableAttributes(
+        op->getDiscardableAttrDictionary(),
+        existing->getDiscardableAttrDictionary());
+    if (existing->getDiscardableAttrDictionary() != attr)
+      existing->setDiscardableAttrs(attr);
+  }
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
Index: mlir/test/Transforms/cse.mlir
===================================================================
--- mlir/test/Transforms/cse.mlir
+++ mlir/test/Transforms/cse.mlir
@@ -493,3 +493,39 @@
 // CHECK: }
 // CHECK-NOT: scf.if
 // CHECK: return %[[if]], %[[if]]
+
+// The smallest `test.non_essential` value is kept when ops are merged.
+func.func @cse_test_dialect_interface() -> (i32, i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.non_essential = "b"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+  %2 = "test.always_speculatable_op"() {test.non_essential = "c"} : () -> i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]], %[[result]]
+
+// The kept op adopts `test.non_essential` from the eliminated op when it has
+// none of its own.
+func.func @cse_non_essential_attr_adopted() -> (i32, i32) {
+  %0 = "test.always_speculatable_op"() : () -> i32
+  %1 = "test.always_speculatable_op"() {test.non_essential = "x"} : () -> i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @cse_non_essential_attr_adopted
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "x"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]]
+
+// Ops whose essential attributes differ are not merged.
+func.func @cse_essential_attr_differs() -> (i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.essential = 0 : i32, test.non_essential = "a"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.essential = 1 : i32, test.non_essential = "a"} : () -> i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @cse_essential_attr_differs
+// CHECK-NEXT: %[[a:.+]] = "test.always_speculatable_op"() {test.essential = 0 : i32, test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: %[[b:.+]] = "test.always_speculatable_op"() {test.essential = 1 : i32, test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[a]], %[[b]]
Index: mlir/test/lib/Dialect/Test/TestDialect.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/CSEUtils.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
@@ -457,6 +458,62 @@
   }
 };
 
+/// This class defines the interface for customizing OperationEquivalence:
+/// `test.non_essential` is ignored when comparing and hashing operations.
+struct TestOperationEquivalenceInterface
+    : public DialectOperationEquivalenceInterface {
+  using DialectOperationEquivalenceInterface::
+      DialectOperationEquivalenceInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Returns true if the given 'attr' contains attributes that can be
+  /// ignored in the equivalence comparison.
+  bool containsNonEssentialAttribute(DictionaryAttr attr) const override {
+    return attr.contains("test.non_essential");
+  }
+
+  /// Returns true if the given named attribute 'namedAttr' can be ignored
+  /// in the equivalence comparison.
+  bool isNonEssentialAttribute(NamedAttribute namedAttr) const override {
+    return namedAttr.getName() == "test.non_essential";
+  }
+};
+
+/// This class defines the interface for customizing CSE: the kept operation
+/// ends up with the lexicographically smallest `test.non_essential` value.
+struct TestCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Returns `existingOpAttr` with `test.non_essential` set to the smaller of
+  /// the string values carried by `opAttr` (attributes of the op being
+  /// eliminated) and `existingOpAttr`. If only `opAttr` carries a string
+  /// value, it is copied; if `opAttr` carries none, `existingOpAttr` is
+  /// returned unchanged.
+  DictionaryAttr
+  combineDiscardableAttributes(DictionaryAttr opAttr,
+                               DictionaryAttr existingOpAttr) const override {
+    auto src = opAttr.getAs<StringAttr>("test.non_essential");
+    if (!src)
+      return existingOpAttr;
+
+    // Keep the existing value unless the eliminated op's value is smaller.
+    auto dest = existingOpAttr.getAs<StringAttr>("test.non_essential");
+    if (dest && dest.getValue() <= src.getValue())
+      return existingOpAttr;
+
+    NamedAttrList attributes(existingOpAttr);
+    attributes.set("test.non_essential", src);
+    return attributes.getDictionary(getContext());
+  }
+};
+
 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
 public:
   TestReductionPatternInterface(Dialect *dialect)
@@ -564,7 +621,8 @@
   addInterface<TestOpAsmInterface>(blobInterface);
 
   addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+                TestReductionPatternInterface, TestBytecodeDialectInterface,
+                TestOperationEquivalenceInterface, TestCSEInterface>();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific
+ bool containsNonEssentialAttribute(DictionaryAttr attr) const override { + return attr.contains("test.non_essential"); + } + + /// Returns true if the given named attribute 'namedAttr' can be ignored + /// in the equivalence comparison. + bool isNonEssentialAttribute(NamedAttribute namedAttr) const override { + return namedAttr.getName() == "test.non_essential"; + } +}; + +/// This class defines the interface for customizing CSE. +struct TestCSEInterface : public DialectCSEInterface { + using DialectCSEInterface::DialectCSEInterface; + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Returns a dictionary attribute that are combined derived from + /// `opAttr` (attributes attached to an op that is going to be eliminated) and + /// `existingOpAttr`. + DictionaryAttr + combineDiscardableAttributes(DictionaryAttr opAttr, + DictionaryAttr existingOpAttr) const override { + if (auto src = opAttr.getAs("test.non_essential")) { + StringAttr newValue; + // If both has "test.non_essential" attr, set the smaller value. + if (auto dest = existingOpAttr.getAs("test.non_essential")) { + if (src.getValue() < dest.getValue()) + newValue = src; + } else { + newValue = src; + } + + if (newValue) { + NamedAttrList attributes(existingOpAttr); + attributes.set("test.non_essential", newValue); + existingOpAttr = attributes.getDictionary(getContext()); + } + } + + return existingOpAttr; + } +}; + struct TestReductionPatternInterface : public DialectReductionPatternInterface { public: TestReductionPatternInterface(Dialect *dialect) @@ -564,7 +623,8 @@ addInterface(blobInterface); addInterfaces(); + TestReductionPatternInterface, TestBytecodeDialectInterface, + TestOperationEquivalenceInterface, TestCSEInterface>(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific