diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -156,6 +156,10 @@ /// instances. This should not be used directly. StorageUniquer &getAttributeUniquer(); + /// Returns the storage uniquer used for constructing uniqued side effects. + /// This should not be used directly. + StorageUniquer &getEffectUniquer(); + /// These APIs are tracking whether the context will be used in a /// multithreading environment: this has no effect other than enabling /// assertions on misuses of some APIs. diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -15,60 +15,116 @@ #define MLIR_INTERFACES_SIDEEFFECTS_H #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StorageUniquerSupport.h" namespace mlir { namespace SideEffects { -//===----------------------------------------------------------------------===// -// Effects -//===----------------------------------------------------------------------===// -/// This class represents a base class for a specific effect type. -class Effect { +namespace detail { +class EffectUniquer; +} // namespace detail + +/// The default storage class for Effects. Only contains the type ID of the +/// effect. +class EffectStorage : public StorageUniquer::BaseStorage { + friend StorageUniquer; + friend MLIRContext; + public: - /// This base class is used for derived effects that are non-parametric. - template - class Base : public BaseEffect { - public: - using BaseT = Base; + /// Constructs a storage with the given type ID. + explicit EffectStorage(const TypeID &id) : typeID(id) {} - /// Return the unique identifier for the base effects class. - static TypeID getEffectID() { return TypeID::get(); } + /// Returns the type ID of the storage.
+ TypeID getTypeID() const { return typeID; } - /// 'classof' used to support llvm style cast functionality. - static bool classof(const ::mlir::SideEffects::Effect *effect) { - return effect->getEffectID() == BaseT::getEffectID(); - } +private: + /// Friends can construct a storage and initialize it later. + EffectStorage() {} + void initialize(const TypeID &id) { typeID = id; } - /// Returns a unique instance for the derived effect class. - static DerivedEffect *get() { - return BaseEffect::template get(); - } - using BaseEffect::get; + /// Type ID of the effect. + TypeID typeID; +}; - protected: - Base() : BaseEffect(BaseT::getEffectID()) {} - }; +namespace detail { +/// A utility class to get or create unique instances of side effects within an +/// MLIRContext. +class EffectUniquer { +public: + /// Get a uniqued non-parametric side effect. + // TODO: support parametric side effects + template + static T get(MLIRContext *ctx) { + return ctx->getEffectUniquer().get(T::getTypeID()); + } +}; +} // namespace detail - /// Return the unique identifier for the base effects class. - TypeID getEffectID() const { return id; } +//===----------------------------------------------------------------------===// +// Effects +//===----------------------------------------------------------------------===// - /// Returns a unique instance for the given effect class. - template static DerivedEffect *get() { - static_assert(std::is_base_of::value, - "expected DerivedEffect to inherit from Effect"); +/// This class represents a base class for a specific effect type. +class Effect { +public: + template + using Base = + mlir::detail::StorageUserBase; + using ImplType = EffectStorage; - static DerivedEffect instance; - return &instance; - } + /// Return the unique identifier for the base effects class. + TypeID getTypeID() const { return impl->getTypeID(); } + + /// Support for LLVM-style casting.
+ template + bool isa() const; + template + bool isa() const; + template + U dyn_cast() const; + template + U dyn_cast_or_null() const; + template + U cast() const; + + /// Support casting to itself. + static bool classof(Effect) { return true; } + + Effect() : impl(nullptr) {} + /*implicit*/ Effect(ImplType *impl) : impl(impl) {} + Effect(const Effect &) = default; protected: - Effect(TypeID id) : id(id) {} - -private: - /// The id of the derived effect class. - TypeID id; + ImplType *impl; }; +template +bool Effect::isa() const { + assert(impl && "isa<> used on a null effect."); + return U::classof(*this); +} + +template +bool Effect::isa() const { + return isa() || isa(); +} + +template +U Effect::dyn_cast() const { + return isa() ? U(impl) : U(nullptr); +} +template +U Effect::dyn_cast_or_null() const { + return (impl && isa()) ? U(impl) : U(nullptr); +} +template +U Effect::cast() const { + assert(isa()); + return U(impl); +} + //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -135,14 +191,14 @@ /// argument) that the effect is applied to. template class EffectInstance { public: - EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get()) + EffectInstance(EffectT effect, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource) {} - EffectInstance(EffectT *effect, Value value, + EffectInstance(EffectT effect, Value value, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), value(value) {} /// Return the effect being applied. - EffectT *getEffect() const { return effect; } + EffectT getEffect() const { return effect; } /// Return the value the effect is applied on, or nullptr if there isn't a /// known value being affected. @@ -153,7 +209,7 @@ private: /// The specific effect being applied.
- EffectT *effect; + EffectT effect; /// The resource that the given value resides in. Resource *resource; @@ -197,22 +253,30 @@ /// The following effect indicates that the operation allocates from some /// resource. An 'allocate' effect implies only allocation of the resource, and /// not any visible mutation or dereference. -struct Allocate : public Effect::Base {}; +struct Allocate : public Effect::Base { + using Base::Base; +}; /// The following effect indicates that the operation frees some resource that -/// has been allocated. An 'allocate' effect implies only de-allocation of the +/// has been allocated. A 'free' effect implies only de-allocation of the /// resource, and not any visible allocation, mutation or dereference. -struct Free : public Effect::Base {}; +struct Free : public Effect::Base { + using Base::Base; +}; /// The following effect indicates that the operation reads from some resource. /// A 'read' effect implies only dereferencing of the resource, and not any /// visible mutation. -struct Read : public Effect::Base {}; +struct Read : public Effect::Base { + using Base::Base; +}; /// The following effect indicates that the operation writes to some resource. A /// 'write' effect implies only mutating a resource, and not any visible /// dereference or read.
-struct Write : public Effect::Base {}; +struct Write : public Effect::Base { + using Base::Base; +}; } // namespace MemoryEffects //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td @@ -91,7 +91,7 @@ }] # baseEffect # [{>> &effects) { getEffects(effects); llvm::erase_if(effects, [&](auto &it) { - return !llvm::isa(it.getEffect()); + return !it.getEffect().template isa(); }); } @@ -100,7 +100,7 @@ SmallVector, 4> effects; getEffects(effects); return llvm::any_of(effects, [](const auto &it) { - return llvm::isa(it.getEffect()); + return it.getEffect().template isa(); }); } @@ -109,7 +109,7 @@ SmallVector, 4> effects; getEffects(effects); return !effects.empty() && llvm::all_of(effects, [](const auto &it) { - return isa(it.getEffect()); + return it.getEffect().template isa(); }); } diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -51,7 +51,7 @@ // original buffer.
if (llvm::any_of( effects, [](const MemoryEffects::EffectInstance &instance) { - return isa(instance.getEffect()); + return instance.getEffect().isa(); })) return v; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -345,6 +346,13 @@ UnknownLoc unknownLocAttr; DictionaryAttr emptyDictionaryAttr; + //===--------------------------------------------------------------------===// + // Side effect uniquing + //===--------------------------------------------------------------------===// + + /// Effect uniquing. + StorageUniquer effectUniquer; + public: MLIRContextImpl() : identifiers(identifierAllocator) {} ~MLIRContextImpl() { @@ -418,6 +426,19 @@ impl->affineUniquer .registerParametricStorageType(); impl->affineUniquer.registerParametricStorageType(); + + // Register the memory effects with the uniquer.
+ auto registerMemoryEffect = [this](TypeID typeID) { + impl->effectUniquer + .registerSingletonStorageType( + typeID, [typeID](SideEffects::EffectStorage *storage) { + storage->initialize(typeID); + }); + }; + registerMemoryEffect(MemoryEffects::Allocate::getTypeID()); + registerMemoryEffect(MemoryEffects::Free::getTypeID()); + registerMemoryEffect(MemoryEffects::Read::getTypeID()); + registerMemoryEffect(MemoryEffects::Write::getTypeID()); } MLIRContext::~MLIRContext() {} @@ -832,6 +853,10 @@ return context->getImpl().emptyDictionaryAttr; } +StorageUniquer &MLIRContext::getEffectUniquer() { + return getImpl().effectUniquer; +} + //===----------------------------------------------------------------------===// // AffineMap uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp --- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp +++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp @@ -22,7 +22,7 @@ //===----------------------------------------------------------------------===// bool MemoryEffects::Effect::classof(const SideEffects::Effect *effect) { - return isa(effect); + return effect->isa(); } //===----------------------------------------------------------------------===// @@ -65,10 +65,10 @@ if (!llvm::all_of(effects, [op](const MemoryEffects::EffectInstance &it) { // We can drop allocations if the value is a result of the // operation. - if (isa(it.getEffect())) + if (it.getEffect().isa()) return it.getValue() && it.getValue().getDefiningOp() == op; // Otherwise, the effect must be a read.
- return isa(it.getEffect()); + return it.getEffect().isa(); })) { return false; } diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -269,7 +269,7 @@ effects, std::back_inserter(allocateResultEffects), [=](MemoryEffects::EffectInstance &it) { Value value = it.getValue(); - return isa(it.getEffect()) && value && + return it.getEffect().isa() && value && value.isa() && it.getResource() != SideEffects::AutomaticAllocationScopeResource::get(); @@ -559,7 +559,7 @@ effectInterface.getEffectsOnValue(entry.allocValue, effects); return llvm::any_of( effects, [&](MemoryEffects::EffectInstance &it) { - return isa(it.getEffect()); + return it.getEffect().isa(); }); }); // Assign the associated dealloc operation (if any). diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -639,13 +639,13 @@ DictionaryAttr effectElement = element.cast(); // Get the specific memory effect. - MemoryEffects::Effect *effect = - llvm::StringSwitch( + MemoryEffects::Effect effect = + llvm::StringSwitch( effectElement.get("effect").cast().getValue()) - .Case("allocate", MemoryEffects::Allocate::get()) - .Case("free", MemoryEffects::Free::get()) - .Case("read", MemoryEffects::Read::get()) - .Case("write", MemoryEffects::Write::get()); + .Case("allocate", MemoryEffects::Allocate::get(getContext())) + .Case("free", MemoryEffects::Free::get(getContext())) + .Case("read", MemoryEffects::Read::get(getContext())) + .Case("write", MemoryEffects::Write::get(getContext())); // Check for a result to affect.
Value value; diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -32,13 +32,13 @@ for (MemoryEffects::EffectInstance instance : effects) { auto diag = op.emitRemark() << "found an instance of "; - if (isa(instance.getEffect())) + if (instance.getEffect().isa()) diag << "'allocate'"; - else if (isa(instance.getEffect())) + else if (instance.getEffect().isa()) diag << "'free'"; - else if (isa(instance.getEffect())) + else if (instance.getEffect().isa()) diag << "'read'"; - else if (isa(instance.getEffect())) + else if (instance.getEffect().isa()) diag << "'write'"; if (instance.getValue()) diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td --- a/mlir/test/mlir-tblgen/op-side-effects.td +++ b/mlir/test/mlir-tblgen/op-side-effects.td @@ -20,9 +20,9 @@ // CHECK: void SideEffectOpA::getEffects // CHECK: for (::mlir::Value value : getODSOperands(0)) -// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get()); +// CHECK: effects.emplace_back(MemoryEffects::Read::get(getContext()), value, ::mlir::SideEffects::DefaultResource::get()); // CHECK: for (::mlir::Value value : getODSResults(0)) -// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get()); +// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(getContext()), value, CustomResource::get()); // CHECK: void SideEffectOpB::getEffects -// CHECK: effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get()); +// CHECK: effects.emplace_back(MemoryEffects::Write::get(getContext()), CustomResource::get()); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1641,7 +1641,7 @@ } body << "
effects.emplace_back(" << location.effect.getName() - << "::get()"; + << "::get(getContext())"; // If the effect isn't static, it has a specific value attached to it. if (location.kind != EffectKind::Static)