diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -91,10 +91,7 @@ let hasFolder = 1; } -def Shape_ConstShapeOp : Shape_Op<"const_shape", - [ConstantLike, - NoSideEffect, - DeclareOpInterfaceMethods]> { +def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> { let summary = "Creates a constant of !shape.shape type."; let description = [{ Creates a !shape.shape with rank given by the length of `shape` and with diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -230,6 +230,9 @@ std::vector getAllFields() const; }; +// Name of the infer type op interface. +extern const char *inferTypeOpInterface; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -23,6 +23,7 @@ #include "mlir/TableGen/Type.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" @@ -227,10 +228,45 @@ // debugging purposes. void print(llvm::raw_ostream &os) const; + // Return whether all the result types are known. + bool allResultTypesKnown() const { return allResultsHaveKnownTypes; } + + // Pair representing either an index to an argument or a type constraint. Only + // one of these entries should have the non-default value.
+ struct ArgOrType { + explicit ArgOrType(int index) : index(index), constraint(None) {} + explicit ArgOrType(TypeConstraint constraint) + : index(None), constraint(constraint) {} + bool isArg() const { + assert(constraint.hasValue() ^ index.hasValue()); + return index.hasValue(); + } + bool isType() const { + assert(constraint.hasValue() ^ index.hasValue()); + return constraint.hasValue(); + } + + int getArg() const { return *index; } + TypeConstraint getType() const { return *constraint; } + + private: + Optional index; + Optional constraint; + }; + + // Return all arguments or type constraints with same type as result[index]. + // Requires: all result types are known. + ArrayRef getSameTypeAsResult(int index) const; + private: // Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); + // Populates type inference info (mostly equality) given a mapping from + // names to indices for arguments and results. + void populateTypeInferenceInfo( + const llvm::StringMap &argumentsAndResultsIndex); + // The dialect of this op. Dialect dialect; @@ -261,12 +297,18 @@ // The regions of this op. SmallVector regions; + // For each result, the arguments or type constraints with the same type. + SmallVector, 4> resultTypeMapping; + // The number of native attributes stored in the leading positions of // `attributes`. int numNativeAttributes; // The TableGen definition of this op. const llvm::Record &def; + + // Whether the types of all results are known.
+ bool allResultsHaveKnownTypes = false; }; } // end namespace tblgen diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -223,15 +223,6 @@ OpFoldResult ConstShapeOp::fold(ArrayRef) { return shape(); } -LogicalResult -ConstShapeOp::inferReturnTypes(MLIRContext *context, - Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back(ShapeType::get(context)); - return success(); -} - //===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -288,3 +288,5 @@ return attributes; } + +const char *mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -14,6 +14,8 @@ #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Type.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -155,13 +157,13 @@ const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { for (const auto &t : traits) { - if (auto opTrait = dyn_cast(&t)) { + if (const auto *opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return opTrait; - } else if (auto opTrait = dyn_cast(&t)) { + } else if (const auto *opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return opTrait; - } else if (auto opTrait = dyn_cast(&t)) { + } else if (const auto *opTrait = dyn_cast(&t)) { if
(opTrait->getTrait() == trait) return opTrait; } @@ -252,22 +254,127 @@ return arguments[index]; } +// Key of result `i` in the name-to-index mapping: arguments use their getArg +// index while results are negative (-1 - i); the mapping is self-inverse. +static int resultIndex(int i) { return -1 - i; } + +bool tblgen::Operator::isVariadic() const { + for (const auto &op : operands) + if (op.isVariadic()) + return true; + for (const auto &res : results) + if (res.isVariadic()) + return true; + return false; +} + +void tblgen::Operator::populateTypeInferenceInfo( + const llvm::StringMap &argumentsAndResultsIndex) { + // If the type inference op interface is not registered, then do not attempt + // to determine if the result types can be inferred. + auto &recordKeeper = def.getRecords(); + auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); + allResultsHaveKnownTypes = false; + if (!inferTrait) + return; + + // If there are no results, then skip this, as otherwise the generated build + // method would overlap with another autogenerated builder. + if (getNumResults() == 0) + return; + + // Skip for ops with variadic operands/results. + // TODO: This can be relaxed. + if (isVariadic()) + return; + + // Skip cases currently being custom generated. + // TODO: Remove special cases. + if (getTrait("OpTrait::SameOperandsAndResultType")) + return; + + llvm::EquivalenceClasses ecs; + resultTypeMapping.resize(getNumResults()); + // Captures the operands and buildable result type constraints that share + // result `i`'s type; attributes are skipped. Returns whether any matched. + auto captureMapping = [&](int i) { + bool found = false; + ecs.insert(resultIndex(i)); + auto mi = ecs.findLeader(resultIndex(i)); + for (auto me = ecs.member_end(); mi != me; ++mi) { + if (*mi < 0) { + auto tc = getResultTypeConstraint(resultIndex(*mi)); + if (tc.getBuilderCall().hasValue()) { + resultTypeMapping[i].emplace_back(tc); + found = true; + } + continue; + } + + if (auto *attr = getArg(*mi).dyn_cast()) { + // TODO: Handle attributes.
+ continue; + } else { + resultTypeMapping[i].emplace_back(*mi); + found = true; + } + } + return found; + }; + + for (const OpTrait &trait : traits) { + const llvm::Record &traitDef = trait.getDef(); + // If the infer type op interface was manually added, then treat it as + // intention that the op needs special handling. + // TODO: Reconsider whether to always generate, this is more conservative + // and keeps existing behavior so starting that way for now. + if (traitDef.isSubClassOf( + llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) + return; + if (const auto *opTrait = dyn_cast(&trait)) + if (opTrait->getTrait().startswith(inferTypeOpInterface)) + return; + + if (!traitDef.isSubClassOf("AllTypesMatch")) + continue; + + auto values = traitDef.getValueAsListOfStrings("values"); + auto root = argumentsAndResultsIndex.lookup(values.front()); + for (StringRef str : values) + ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); + } + + // Verifies that every result type is known, either from a same-typed operand + // or from a buildable result type constraint, and records the matching + // entries in resultTypeMapping. + allResultsHaveKnownTypes = + all_of(llvm::seq(0, getNumResults()), captureMapping); + + // If the types could be computed, then add type inference trait.
+ if (allResultsHaveKnownTypes) + traits.push_back(OpTrait::create(inferTrait->getDefInit())); +} + void tblgen::Operator::populateOpStructure() { auto &recordKeeper = def.getRecords(); - auto typeConstraintClass = recordKeeper.getClass("TypeConstraint"); - auto attrClass = recordKeeper.getClass("Attr"); - auto derivedAttrClass = recordKeeper.getClass("DerivedAttr"); - auto opVarClass = recordKeeper.getClass("OpVariable"); + auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); + auto *attrClass = recordKeeper.getClass("Attr"); + auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); + auto *opVarClass = recordKeeper.getClass("OpVariable"); numNativeAttributes = 0; DagInit *argumentValues = def.getValueAsDag("arguments"); unsigned numArgs = argumentValues->getNumArgs(); + // Mapping from argument and result names to indices. Arguments are indexed + // to match getArg index, while results use negative indices (resultIndex). + llvm::StringMap argumentsAndResultsIndex; + // Handle operands and native attributes. for (unsigned i = 0; i != numArgs; ++i) { - auto arg = argumentValues->getArg(i); + auto *arg = argumentValues->getArg(i); auto givenName = argumentValues->getArgNameStr(i); - auto argDefInit = dyn_cast(arg); + auto *argDefInit = dyn_cast(arg); if (!argDefInit) PrintFatalError(def.getLoc(), Twine("undefined type for argument #") + Twine(i)); @@ -290,6 +397,8 @@ PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " "from TypeConstraint or Attr are allowed"); } + if (!givenName.empty()) + argumentsAndResultsIndex[givenName] = i; } // Handle derived attributes.
@@ -348,6 +457,8 @@ if (resultDef->isSubClassOf(opVarClass)) resultDef = resultDef->getValueAsDef("constraint"); results.push_back({name, TypeConstraint(resultDef)}); + if (!name.empty()) + argumentsAndResultsIndex[name] = resultIndex(i); } // Handle successors @@ -375,17 +486,19 @@ // Create list of traits, skipping over duplicates: appending to lists in // tablegen is easy, making them unique less so, so dedupe here. - if (auto traitList = def.getValueAsListInit("traits")) { + if (auto *traitList = def.getValueAsListInit("traits")) { // This is uniquing based on pointers of the trait. SmallPtrSet traitSet; traits.reserve(traitSet.size()); - for (auto traitInit : *traitList) { + for (auto *traitInit : *traitList) { // Keep traits in the same order while skipping over duplicates. if (traitSet.insert(traitInit).second) traits.push_back(OpTrait::create(traitInit)); } } + populateTypeInferenceInfo(argumentsAndResultsIndex); + // Handle regions auto *regionsDag = def.getValueAsDag("regions"); auto *regionsOp = dyn_cast(regionsDag->getOperator()); @@ -415,6 +528,12 @@ LLVM_DEBUG(print(llvm::dbgs())); } +auto tblgen::Operator::getSameTypeAsResult(int index) const + -> ArrayRef { + assert(allResultTypesKnown() && "requires all result types to be known"); + return resultTypeMapping[index]; +} + ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } bool tblgen::Operator::hasDescription() const { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -756,15 +756,6 @@ def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> { let arguments = (ins I32:$operand); let results = (outs I32); - - let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &state, Value operand", - [{ - state.types.assign({builder.getIntegerType(32)}); - state.addOperands({operand}); - }]> - ]; } def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> { let arguments = (ins I32:$operand); @@
-868,17 +859,6 @@ def TwoResultOp : TEST_Op<"two_result"> { let arguments = (ins MultiResultOpEnum:$kind); let results = (outs I32:$result1, F32:$result2); - - let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &state, IntegerAttr kind", - [{ - auto i32 = builder.getIntegerType(32); - auto f32 = builder.getF32Type(); - state.types.assign({i32, f32}); - state.addAttribute("kind", kind); - }]> - ]; } def AnotherTwoResultOp : TEST_Op<"another_two_result"> { diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -1,6 +1,7 @@ // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck --dump-input-on-failure %s include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def Test_Dialect : Dialect { @@ -44,8 +45,6 @@ }]; } -// CHECK: class AOp; - // CHECK-LABEL: NS::AOp declarations // CHECK: class AOpOperandAdaptor { @@ -125,6 +124,26 @@ // CHECK: Value b(); // CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a) +// Check that inferReturnTypes is generated when result types can be inferred. +// --- + +def NS_FOp : NS_Op<"op_with_all_types_constraint", + [AllTypesMatch<["a", "b"]>]> { + let arguments = (ins AnyType:$a); + let results = (outs AnyType:$b); +} + +// CHECK-LABEL: class FOp : +// CHECK: static LogicalResult inferReturnTypes + +def NS_GOp : NS_Op<"op_with_fixed_return_type", []> { + let arguments = (ins AnyType:$a); + let results = (outs I32:$b); +} + +// CHECK-LABEL: class GOp : +// CHECK: static LogicalResult inferReturnTypes + // Check that default builders can be suppressed.
// --- diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -438,7 +438,7 @@ // ----- func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { - // expected-error@+1 {{all of {x, res} have same type}} + // expected-error@+1 {{type incompatible with return type of operation}} "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32> return } @@ -446,7 +446,7 @@ // ----- func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) { - // expected-error@+1 {{all of {x, res} have same type}} + // expected-error@+1 {{type incompatible with return type of operation}} "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32> return } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -289,9 +289,15 @@ // Generate the OpInterface methods. void genOpInterfaceMethods(); + // Generate the method declarations for a single op interface trait. + void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait); + // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); + // Generate the type inference interface methods. + void genTypeInterfaceMethods(); + private: // The TableGen record for this op. // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly, @@ -315,6 +321,7 @@ verifyCtx.withOp("(*this->getOperation())"); genTraits(); + // Generate C++ code for various op methods. The order here determines the // methods in the generated file.
genOpAsmInterface(); @@ -335,6 +342,7 @@ genOpInterfaceMethods(); generateOpFormat(op, opClass); genSideEffectInterfaceMethods(); + genTypeInterfaceMethods(); } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { @@ -744,6 +752,10 @@ return canGenerate; } +static bool canInferType(const Operator &op) { + return op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; +} + void OpEmitter::genSeparateArgParamBuilder() { SmallVector attrBuilderType; attrBuilderType.push_back(AttrParamKind::WrappedAttr); @@ -808,11 +820,9 @@ llvm_unreachable("unhandled TypeParamKind"); }; - bool canInferType = - op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); - if (canInferType) + if (canInferType(op)) emit(attrType, TypeParamKind::None, /*inferType=*/true); // Emit separate arg build with collective type, unless there is only one // variadic result, in which case the above would have already generated @@ -1064,11 +1074,8 @@ body << " " << builderOpState << ".addTypes(resultTypes);\n"; // Generate builder that infers type too. - // TODO(jpienaar): Subsume this with general checking if type can be inferred - // automatically. // TODO(jpienaar): Expand to handle regions and successors. - if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 && - op.getNumSuccessors() == 0) + if (canInferType(op) && op.getNumSuccessors() == 0) genInferredTypeCollectiveParamBuilder(); } @@ -1312,40 +1319,43 @@ } } +void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) { + auto interface = opTrait->getOpInterface(); + + // Get the set of methods that should always be declared.
+ auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); + llvm::StringSet<> alwaysDeclaredMethods; + alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), + alwaysDeclaredMethodsVec.end()); + + for (const OpInterfaceMethod &method : interface.getMethods()) { + // Don't declare methods that already have a body. + if (method.getBody()) + continue; + // Don't declare if the method has a default implementation and the op + // didn't request that it always be declared. + if (method.getDefaultImplementation() && + !alwaysDeclaredMethods.count(method.getName())) + continue; + + std::string args; + llvm::raw_string_ostream os(args); + interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + opClass.newMethod(method.getReturnType(), method.getName(), os.str(), + method.isStatic() ? OpMethod::MP_Static + : OpMethod::MP_None, + /*declOnly=*/true); + } +} + void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { - auto opTrait = dyn_cast(&trait); - if (!opTrait || !opTrait->shouldDeclareMethods()) - continue; - auto interface = opTrait->getOpInterface(); - - // Get the set of methods that should always be declared. - auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); - llvm::StringSet<> alwaysDeclaredMethods; - alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), - alwaysDeclaredMethodsVec.end()); - - for (const OpInterfaceMethod &method : interface.getMethods()) { - // Don't declare if the method has a body. - if (method.getBody()) - continue; - // Don't declare if the method has a default implementation and the op - // didn't request that it always be declared.
- if (method.getDefaultImplementation() && - !alwaysDeclaredMethods.count(method.getName())) - continue; - - std::string args; - llvm::raw_string_ostream os(args); - interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { - os << arg.type << " " << arg.name; - }); - opClass.newMethod(method.getReturnType(), method.getName(), os.str(), - method.isStatic() ? OpMethod::MP_Static - : OpMethod::MP_None, - /*declOnly=*/true); - } + if (const auto *opTrait = dyn_cast(&trait)) + if (opTrait->shouldDeclareMethods()) + genOpInterfaceMethod(opTrait); } } @@ -1425,6 +1435,46 @@ } } +void OpEmitter::genTypeInterfaceMethods() { + if (!op.allResultTypesKnown()) + return; + + auto &method = opClass.newMethod( + "LogicalResult", "inferReturnTypes", + "MLIRContext* context, Optional location, " + "ValueRange operands, DictionaryAttr attributes, RegionRange regions, " + "SmallVectorImpl& inferredReturnTypes", + OpMethod::MP_Static, + /*declOnly=*/false); + auto &os = method.body(); + os << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; + + FmtContext fctx; + fctx.withBuilder("odsBuilder"); + os << " Builder odsBuilder(context);\n"; + + auto emitType = + [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & { + if (type.isArg()) { + auto argIndex = type.getArg(); + assert(!op.getArg(argIndex).is()); + return os << "operands[" << argIndex << "].getType()"; + } else { + return os << tgfmt(*type.getType().getBuilderCall(), &fctx); + } + }; + + for (int i = 0, e = op.getNumResults(); i != e; ++i) { + os << " inferredReturnTypes[" << i << "] = "; + auto types = op.getSameTypeAsResult(i); + emitType(types[0]) << ";\n"; + if (types.size() == 1) + continue; + // TODO: Could also check the other entries are equal; left to the verifier. + } + os << " return success();"; +} + void OpEmitter::genParser() { if (!hasStringAttribute(def, "parser") || hasStringAttribute(def, "assemblyFormat"))