diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -442,7 +442,7 @@ // Provide only a default definition of the method. // Note: `ConcreteOp` corresponds to the derived operation typename. InterfaceMethod<"/*insert doc here*/", - "unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{ + "unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{ ConcreteOp op = cast<ConcreteOp>(getOperation()); return op.getNumInputs() + op.getNumOutputs(); }]>, @@ -455,6 +455,13 @@ // declaration but instead handled by the op interface trait directly. def OpWithInferTypeInterfaceOp : Op<... [DeclareOpInterfaceMethods<MyInterface>]> { ... } + +// Methods that have a default implementation do not have declarations +// generated. If an operation wishes to override the default behavior, it can +// explicitly specify the method that it wishes to override. This will force +// the generation of a declaration for those methods. +def OpWithOverrideInferTypeInterfaceOp : Op<... + [DeclareOpInterfaceMethods<MyInterface, ["getNumWithDefault"]>]> { ... } ``` A verification method can also be specified on the `OpInterface` by setting diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -567,7 +567,8 @@ //===----------------------------------------------------------------------===// def BranchOp : Std_Op<"br", - [DeclareOpInterfaceMethods<BranchOpInterface>, NoSideEffect, Terminator]> { + [DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, + NoSideEffect, Terminator]> { let summary = "branch operation"; let description = [{ The `br` operation represents a branch operation in a function. @@ -603,10 +604,6 @@ /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); - - /// Returns the successor that would be chosen with the given constant - /// operands. 
Returns nullptr if a single successor could not be chosen. - Block *getSuccessorForOperands(ArrayRef<Attribute>); }]; let hasCanonicalizer = 1; @@ -996,7 +993,8 @@ //===----------------------------------------------------------------------===// def CondBranchOp : Std_Op<"cond_br", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, NoSideEffect, Terminator]> { let summary = "conditional branch operation"; let description = [{ @@ -1103,10 +1101,6 @@ eraseSuccessorOperand(falseIndex, index); } - /// Returns the successor that would be chosen with the given constant - /// operands. Returns nullptr if a single successor could not be chosen. - Block *getSuccessorForOperands(ArrayRef<Attribute> operands); - private: /// Get the index of the first true destination operand. unsigned getTrueDestOperandIndex() { return 1; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1773,12 +1773,20 @@ // Whether to declare the op interface methods in the op's header. This class // simply wraps an OpInterface but is used to indicate that the method -// declarations should be generated. -class DeclareOpInterfaceMethods<OpInterface interface> : - OpInterface<interface.cppClassName> { +// declarations should be generated. This class takes an optional set of methods +// that should have declarations generated even if the method has a default +// implementation. +class DeclareOpInterfaceMethods<OpInterface interface, + list<string> overridenMethods = []> + : OpInterface<interface.cppClassName> { let description = interface.description; let cppClassName = interface.cppClassName; let methods = interface.methods; + + // This field contains a set of method names that should always have their + // declarations generated. This allows for generating declarations for + // methods with default implementations that need to be overridden. 
+ list<string> alwaysOverriddenMethods = overridenMethods; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -105,6 +105,10 @@ // Whether the declaration of methods for this trait should be emitted. bool shouldDeclareMethods() const; + + // Returns the methods that should always be declared if this interface is + // emitting declarations. + std::vector<StringRef> getAlwaysDeclaredMethods() const; }; } // end namespace tblgen diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -63,3 +63,7 @@ bool InterfaceOpTrait::shouldDeclareMethods() const { return def->isSubClassOf("DeclareOpInterfaceMethods"); } + +std::vector<StringRef> InterfaceOpTrait::getAlwaysDeclaredMethods() const { + return def->getValueAsListOfStrings("alwaysOverriddenMethods"); +} diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -1,4 +1,5 @@ // RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL --dump-input-on-failure +// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=OP_DECL --dump-input-on-failure include "mlir/IR/OpBase.td" @@ -12,6 +13,14 @@ /*methodName=*/"foo", /*args=*/(ins "int":$input) >, + InterfaceMethod< + /*desc=*/[{some function comment}], + /*retTy=*/"int", + /*methodName=*/"default_foo", + /*args=*/(ins "int":$input), + /*body=*/[{}], + /*defaultBody=*/[{ return 0; }] + >, ]; } @@ -27,8 +36,19 @@ def DeclareMethodsOp : Op]>; +def DeclareMethodsWithDefaultOp : Op]>; + // DECL-LABEL: TestOpInterfaceInterfaceTraits // DECL: class TestOpInterface : public OpInterface // DECL: int foo(int input); // 
DECL-NOT: TestOpInterface + +// OP_DECL-LABEL: class DeclareMethodsOp : public +// OP_DECL: int foo(int input); +// OP_DECL-NOT: int default_foo(int input); + +// OP_DECL-LABEL: class DeclareMethodsWithDefaultOp : public +// OP_DECL: int foo(int input); +// OP_DECL: int default_foo(int input); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1280,10 +1280,23 @@ if (!opTrait || !opTrait->shouldDeclareMethods()) continue; auto interface = opTrait->getOpInterface(); - for (auto method : interface.getMethods()) { - // Don't declare if the method has a body or a default implementation. - if (method.getBody() || method.getDefaultImplementation()) + + // Get the set of methods that should always be declared. + auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); + llvm::StringSet<> alwaysDeclaredMethods; + alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), + alwaysDeclaredMethodsVec.end()); + + for (const OpInterfaceMethod &method : interface.getMethods()) { + // Don't declare if the method has a body. + if (method.getBody()) continue; + // Don't declare if the method has a default implementation and the op + // didn't request that it always be declared. + if (method.getDefaultImplementation() && + !alwaysDeclaredMethods.count(method.getName())) + continue; + std::string args; llvm::raw_string_ostream os(args); interleaveComma(method.getArguments(), os,