diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -16,7 +16,6 @@ #include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/TypeName.h" -#include "llvm/Support/raw_ostream.h" namespace mlir { namespace detail { @@ -75,10 +74,28 @@ using InterfaceBase = Interface; + /// This is a special trait that registers a given interface with an object. + template + struct Trait : public BaseTrait { + using ModelT = Model; + + /// Define an accessor for the ID of this interface. + static TypeID getInterfaceID() { return TypeID::get(); } + }; + + /// Construct an interface from an instance of the value type. Interface(ValueT t = ValueT()) : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { - assert((!t || impl) && - "instantiating an interface with an unregistered operation"); + assert((!t || impl) && "expected value to provide interface instance"); + } + + /// Construct an interface instance from a type that implements this + /// interface's trait. + template , T>::value> * = nullptr> + Interface(T t) + : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { + assert((!t || impl) && "expected value to provide interface instance"); } /// Support 'classof' by checking if the given object defines the concrete @@ -88,15 +105,6 @@ /// Define an accessor for the ID of this interface. static TypeID getInterfaceID() { return TypeID::get(); } - /// This is a special trait that registers a given interface with an object. - template - struct Trait : public BaseTrait { - using ModelT = Model; - - /// Define an accessor for the ID of this interface. - static TypeID getInterfaceID() { return TypeID::get(); } - }; - protected: /// Get the raw concept in the correct derived concept type. const Concept *getImpl() const { return impl; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -339,10 +339,9 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter, Args... args) { if (isa(op.getOperation())) - return cast(rewriter.create(args...).getOperation()); + return rewriter.create(args...); if (isa(op.getOperation())) - return cast( - rewriter.create(args...).getOperation()); + return rewriter.create(args...); llvm_unreachable( "expected only linalg.generic or linalg.indexed_generic ops"); return nullptr;