diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1305,13 +1305,6 @@ SignlessIntegerLikeOfAnyRank:$lhs, SignlessIntegerLikeOfAnyRank:$rhs); - let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, ::getI1SameShape(lhs.getType()), - predicate, lhs, rhs); - }]> - ]; - let extraClassDeclaration = [{ static arith::CmpIPredicate getPredicateByName(StringRef name); }]; @@ -1356,13 +1349,6 @@ FloatLike:$lhs, FloatLike:$rhs); - let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, ::getI1SameShape(lhs.getType()), - predicate, lhs, rhs); - }]> - ]; - let extraClassDeclaration = [{ static arith::CmpFPredicate getPredicateByName(StringRef name); }]; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -294,12 +294,6 @@ "the reference to load from", [MemRead]>:$memref); let results = (outs AnyTensor:$result); - let builders = [ - OpBuilder<(ins "Value":$memref), [{ - $_state.addOperands(memref); - $_state.addTypes(memref::getTensorTypeFromMemRefType(memref.getType())); - }]>]; - let extraClassDeclaration = [{ /// The result of a to_tensor is always a tensor. 
TensorType getType() { diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -119,9 +119,6 @@ LLVM_ScalarOrVectorOf]>:$lhs, AnyTypeOf<[LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf]>:$rhs); - let builders = [ - OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)> - ]; let hasCustomAssemblyFormat = 1; string llvmInstName = "ICmp"; string llvmBuilder = [{ @@ -145,9 +142,6 @@ LLVM_ScalarOrVectorOf:$rhs, DefaultValuedAttr:$fastmathFlags); - let builders = [ - OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)> - ]; let hasCustomAssemblyFormat = 1; string llvmInstName = "FCmp"; string llvmBuilder = [{ @@ -583,11 +577,6 @@ let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); let results = (outs LLVM_Type:$res); - let builders = [ - OpBuilder<(ins "Value":$vector, "Value":$position, - CArg<"ArrayRef", "{}">:$attrs)> - ]; - let assemblyFormat = [{ $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector) }]; diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1158,14 +1158,6 @@ Variadic:$indices); let results = (outs AnyType:$result); - let builders = [ - OpBuilder<(ins "Value":$memref, CArg<"ValueRange", "{}">:$indices), [{ - auto memrefType = memref.getType().cast(); - $_state.addOperands(memref); - $_state.addOperands(indices); - $_state.types.push_back(memrefType.getElementType()); - }]>]; - let extraClassDeclaration = [{ Value getMemRef() { return getOperand(0); } void setMemRef(Value value) { setOperand(0, value); } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td --- 
a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -30,11 +30,6 @@ "getUnaryOpResultType($_self)" >])> { let assemblyFormat = "$operand1 `,` $operand2 `:` type($operand1) attr-dict"; - - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs), - [{::buildLogicalBinaryOp($_builder, $_state, lhs, rhs);}]> - ]; } class SPIRV_LogicalUnaryOp])> { let assemblyFormat = "$operand `:` type($operand) attr-dict"; - - let builders = [ - OpBuilder<(ins "Value":$value), - [{::buildLogicalUnaryOp($_builder, $_state, value);}]> - ]; } // ----- diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -237,12 +237,6 @@ let results = (outs AnyType:$result); let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)"; - let builders = [ - OpBuilder<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices), [{ - auto resType = tensor.getType().cast().getElementType(); - build($_builder, $_state, resType, tensor, indices); - }]>]; - let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; @@ -292,7 +286,7 @@ between different flavors of ops on that operate on tensors. #### Verification vs Inference in the rank-reduced case - + Note that there may be multiple ways to infer a resulting rank-reduced type. e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes. @@ -724,13 +718,6 @@ $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest) }]; - let builders = [ - OpBuilder<(ins "Value":$scalar, "Value":$dest, - CArg<"ValueRange", "{}">:$indices), [{ - auto resType = dest.getType(); - build($_builder, $_state, resType, scalar, dest, indices); - }]>]; - let extraClassDeclaration = [{ std::pair getDpsInitsPositionRange() { return {1, 2}; // `dest` operand @@ -795,7 +782,7 @@ behavior of tensor.extract_slice. 
#### Verification in the rank-reduced case - + The same verification discussion and mechanisms apply as for ExtractSliceOp. Unlike ExtractSliceOp however, there is no need for a specific inference. @@ -1399,7 +1386,7 @@ rank-reducing behavior of tensor.insert_slice and tensor.extract_slice. #### Verification in the rank-reduced case - + The same verification discussion and mechanisms apply as for ExtractSliceOp. Unlike ExtractSliceOp however, there is no need for a specific inference. }]; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -576,8 +576,6 @@ let builders = [ // 0-D builder. OpBuilder<(ins "Value":$source)>, - // 1-D + position builder. - OpBuilder<(ins "Value":$source, "Value":$position)>, ]; let extraClassDeclaration = [{ VectorType getVectorType() { diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -37,6 +37,39 @@ namespace mlir { namespace tblgen { +/// This class represents an inferred result type. The result type can be +/// inferred from an argument or result type. If it is inferred from another +/// result type, that type must be buildable or inferred from yet another type. +class InferredResultType { +public: + InferredResultType(int index, std::string transformer) + : index(index), transformer(std::move(transformer)) {} + + /// Returns true if result type is inferred from an argument type. + bool isArg() const { return isArgIndex(index); } + /// Return the mapped argument or result index. + int getIndex() const { return index; } + /// If the type is inferred from a result, return the result index. + int getResultIndex() const { return unmapResultIndex(index); } + + // Mapping from result index to combined argument and result index. 
+ // Arguments are indexed to match getArg index, while the result indexes are + // mapped to avoid overlap. + static int mapResultIndex(int i) { return -1 - i; } + static int unmapResultIndex(int i) { return -i - 1; } + static bool isResultIndex(int i) { return i < 0; } + static bool isArgIndex(int i) { return i >= 0; } + + StringRef getTransformer() const { return transformer; } + +private: + /// The index of the source argument or result. + int index; + + /// The transformer to apply to the type to obtain the inferred type. + std::string transformer; +}; + /// Wrapper class that contains a MLIR op's information (e.g., operands, /// attributes) defined in TableGen and provides helper methods for /// accessing them. @@ -259,32 +292,9 @@ /// Return whether all the result types are known. bool allResultTypesKnown() const { return allResultsHaveKnownTypes; }; - /// Pair representing either a index to an argument or a type constraint. Only - /// one of these entries should have the non-default value. - struct ArgOrType { - explicit ArgOrType(int index) : index(index), constraint(std::nullopt) {} - explicit ArgOrType(TypeConstraint constraint) - : index(std::nullopt), constraint(constraint) {} - bool isArg() const { - assert(constraint.has_value() ^ index.has_value()); - return index.has_value(); - } - bool isType() const { - assert(constraint.has_value() ^ index.has_value()); - return constraint.has_value(); - } - - int getArg() const { return *index; } - TypeConstraint getType() const { return *constraint; } - - private: - std::optional index; - std::optional constraint; - }; - - /// Return all arguments or type constraints with same type as result[index]. + /// Return all arguments or type constraints with same type as result[index]. /// Requires: all result types are known. - ArrayRef getSameTypeAsResult(int index) const; + const InferredResultType &getInferredResultType(int index) const; /// Pair consisting kind of argument and index into operands or attributes. 
struct OperandOrAttribute { @@ -359,7 +369,7 @@ SmallVector regions; /// The argument with the same type as the result. - SmallVector, 4> resultTypeMapping; + SmallVector resultTypeMapping; /// Map from argument to attribute or operand number. SmallVector attrOrOperandMapping; diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -298,7 +298,7 @@ auto maskBuffer = b.create(loc, maskType); b.setInsertionPoint(xferOp); b.create(loc, xferOp.getMask(), maskBuffer); - result.maskBuffer = b.create(loc, maskBuffer); + result.maskBuffer = b.create(loc, maskBuffer, ValueRange()); } return result; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -101,16 +101,6 @@ // Printing, parsing and builder for LLVM::CmpOp. //===----------------------------------------------------------------------===// -void ICmpOp::build(OpBuilder &builder, OperationState &result, - ICmpPredicate predicate, Value lhs, Value rhs) { - build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs); -} - -void FCmpOp::build(OpBuilder &builder, OperationState &result, - FCmpPredicate predicate, Value lhs, Value rhs) { - build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs); -} - void ICmpOp::print(OpAsmPrinter &p) { p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) << ", " << getOperand(1); @@ -1372,20 +1362,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -/// Expects vector to be an LLVM vector type and position to be an integer type. 
-void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, - Value vector, Value position, - ArrayRef attrs) { - auto vectorType = vector.getType(); - auto llvmType = LLVM::getVectorElementType(vectorType); - build(b, result, llvmType, vector, position); - result.addAttributes(attrs); -} - //===----------------------------------------------------------------------===// // ExtractValueOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1005,28 +1005,6 @@ return success(); } -static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, - Value lhs, Value rhs) { - assert(lhs.getType() == rhs.getType()); - - Type boolType = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) - boolType = VectorType::get(vecType.getShape(), boolType); - state.addTypes(boolType); - - state.addOperands({lhs, rhs}); -} - -static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, - Value value) { - Type boolType = builder.getI1Type(); - if (auto vecType = value.getType().dyn_cast()) - boolType = VectorType::get(vecType.getShape(), boolType); - state.addTypes(boolType); - - state.addOperands(value); -} - //===----------------------------------------------------------------------===// // spirv.AccessChainOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1015,12 +1015,6 @@ result.addTypes(source.getType().cast().getElementType()); } -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value position) { - result.addOperands({source, position}); - 
result.addTypes(source.getType().cast().getElementType()); -} - LogicalResult vector::ExtractElementOp::verify() { VectorType vectorType = getVectorType(); if (vectorType.getRank() == 0) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -432,7 +432,8 @@ Value load = b.create( loc, b.create( - loc, MemRefType::get({}, xferOp.getVector().getType()), alloc)); + loc, MemRefType::get({}, xferOp.getVector().getType()), alloc), + ValueRange()); mapping.map(xferOp.getVector(), load); b.clone(*xferOp.getOperation(), mapping); b.create(loc, ValueRange{}); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" +#include #define DEBUG_TYPE "mlir-tblgen-operator" @@ -344,11 +345,6 @@ auto Operator::getArg(int index) const -> Argument { return arguments[index]; } -// Mapping from result index to combined argument and result index. Arguments -// are indexed to match getArg index, while the result indexes are mapped to -// avoid overlap. -static int resultIndex(int i) { return -1 - i; } - bool Operator::isVariadic() const { return any_of(llvm::concat(operands, results), [](const NamedTypeConstraint &op) { return op.isVariadic(); }); @@ -384,46 +380,47 @@ if (operandI == arguments.end()) return; - // Map each of the result types to the anchor operation. + // All result types are inferred from the operand type. 
int operandIdx = operandI - arguments.begin(); - resultTypeMapping.resize(getNumResults()); for (int i = 0; i < getNumResults(); ++i) - resultTypeMapping[i].emplace_back(operandIdx); + resultTypeMapping.emplace_back(operandIdx, "$_self"); allResultsHaveKnownTypes = true; traits.push_back(Trait::create(inferTrait->getDefInit())); return; } - // We create equivalence classes of argument/result types where arguments - // and results are mapped into the same index space and indices corresponding - // to the same type are in the same equivalence class. - llvm::EquivalenceClasses ecs; - resultTypeMapping.resize(getNumResults()); - // Captures the argument whose type matches a given result type. Preference - // towards capturing operands first before attributes. - auto captureMapping = [&](int i) { - bool found = false; - ecs.insert(resultIndex(i)); - auto mi = ecs.findLeader(resultIndex(i)); - for (auto me = ecs.member_end(); mi != me; ++mi) { - if (*mi < 0) { - auto tc = getResultTypeConstraint(i); - if (tc.getBuilderCall()) { - resultTypeMapping[i].emplace_back(tc); - found = true; - } - continue; - } + /// This struct represents a node in this operation's result type inference + /// graph. Each node has a list of incoming type inference edges `sources`. + /// Each edge represents a "source" from which the result type can be + /// inferred, either an operand (leaf) or another result (node). When a node + /// is known to have a fully-inferred type, `inferred` is set to true. + struct ResultTypeInference { + /// The list of incoming type inference edges. + SmallVector sources; + /// This flag is set to true when the result type is known to be inferrable. + bool inferred = false; + }; - resultTypeMapping[i].emplace_back(*mi); - found = true; + // This vector represents the type inference graph, with one node for each + // operation result. The nth element is the node for the nth result. 
+ SmallVector inference(getNumResults(), {}); + + // For all results whose types are buildable, initialize their type inference + // nodes with an edge to themselves. Mark those nodes as fully-inferred. + for (auto &[idx, infer] : llvm::enumerate(inference)) { + if (getResult(idx).constraint.getBuilderCall()) { + infer.sources.emplace_back(InferredResultType::mapResultIndex(idx), + "$_self"); + infer.inferred = true; } - return found; - }; + } + // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the + // result type inference graph. for (const Trait &trait : traits) { const llvm::Record &def = trait.getDef(); + // If the infer type op interface was manually added, then treat it as // intention that the op needs special handling. // TODO: Reconsider whether to always generate, this is more conservative
+ infer.inferred = + InferredResultType::isArgIndex(sourceIndex) || + inference[InferredResultType::unmapResultIndex(sourceIndex)].inferred; + continue; + } + if (!def.isSubClassOf("AllTypesMatch")) continue; auto values = def.getValueAsListOfStrings("values"); - auto root = argumentsAndResultsIndex.lookup(values.front()); - for (StringRef str : values) - ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); + // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That + // is, every result type has an edge from every other type. However, if any + // one of the values refers to an operand or a result with a fully-inferred + // type, we can infer all other types from that value. Try to find a + // fully-inferred type in the list. + std::optional fullyInferredIndex; + SmallVector resultIndices; + for (StringRef name : values) { + int index = argumentsAndResultsIndex.lookup(name); + if (InferredResultType::isResultIndex(index)) + resultIndices.push_back(InferredResultType::unmapResultIndex(index)); + if (InferredResultType::isArgIndex(index) || + inference[InferredResultType::unmapResultIndex(index)].inferred) + fullyInferredIndex = index; + } + if (fullyInferredIndex) { + // Make the fully-inferred type the only source for all results that + // aren't already inferred -- a 1 -> N fanout. + for (int resultIndex : resultIndices) { + ResultTypeInference &infer = inference[resultIndex]; + if (!infer.inferred) { + infer.sources.assign(1, {*fullyInferredIndex, "$_self"}); + infer.inferred = true; + } + } + } else { + // Add an edge between every result and every other type; N <-> N. 
+ for (int resultIndex : resultIndices) { + for (int otherResultIndex : resultIndices) { + if (resultIndex == otherResultIndex) + continue; + inference[resultIndex].sources.emplace_back(otherResultIndex, + "$_self"); + } + } + } } - // Verifies that all output types have a corresponding known input type - // and chooses matching operand or attribute (in that order) that - // matches it. - allResultsHaveKnownTypes = - all_of(llvm::seq(0, getNumResults()), captureMapping); + // Propagate inferredness until a fixed point. + std::list worklist; + for (ResultTypeInference &infer : inference) + if (!infer.inferred) + worklist.push_back(&infer); + bool changed; + do { + changed = false; + // This is `llvm::make_early_inc_range` but keeps the iterator for erasing. + for (auto earlyIncIt = worklist.begin(), cur = earlyIncIt; + cur = earlyIncIt++, cur != worklist.end();) { + ResultTypeInference &infer = **cur; + for (auto &[idx, source] : llvm::enumerate(infer.sources)) { + assert(InferredResultType::isResultIndex(source.getIndex())); + if (inference[InferredResultType::unmapResultIndex(source.getIndex())] + .inferred) { + changed = true; + infer.inferred = true; + // Make this the only source for the result. This breaks any cycles. + infer.sources.assign(1, source); + worklist.erase(cur); + break; + } + } + } + } while (changed); + + allResultsHaveKnownTypes = worklist.empty(); // If the types could be computed, then add type inference trait. 
- if (allResultsHaveKnownTypes) + if (allResultsHaveKnownTypes) { traits.push_back(Trait::create(inferTrait->getDefInit())); + for (const ResultTypeInference &infer : inference) + resultTypeMapping.push_back(infer.sources.front()); + } } void Operator::populateOpStructure() { @@ -562,7 +641,7 @@ resultDef = resultDef->getValueAsDef("constraint"); results.push_back({name, TypeConstraint(resultDef)}); if (!name.empty()) - argumentsAndResultsIndex[name] = resultIndex(i); + argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i); // We currently only support VariadicOfVariadic operands. if (results.back().constraint.isVariadicOfVariadic()) { @@ -683,7 +762,7 @@ LLVM_DEBUG(print(llvm::dbgs())); } -auto Operator::getSameTypeAsResult(int index) const -> ArrayRef { +const InferredResultType &Operator::getInferredResultType(int index) const { assert(allResultTypesKnown()); return resultTypeMapping[index]; } diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -155,6 +155,22 @@ // CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; +def OpL4 : NS_Op<"two_inference_edges", [ + TypesMatchWith<"", "a", "b", "infer0($_self)">, + TypesMatchWith<"", "b", "c", "infer1($_self)">, + TypesMatchWith<"", "input", "a", "fromInput($_self)">]> { + let arguments = (ins I32:$input); + let results = (outs AnyType:$a, AnyType:$b, AnyType:$c); +} + +// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes +// CHECK: odsInferredType0 = fromInput(operands[0].getType()) +// CHECK: odsInferredType1 = infer0(odsInferredType0) +// CHECK: odsInferredType2 = infer1(odsInferredType1) +// CHECK: inferredReturnTypes[0] = odsInferredType0 +// CHECK: inferredReturnTypes[1] = odsInferredType1 +// CHECK: inferredReturnTypes[2] = odsInferredType2 + def OpM : 
NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> { let results = (outs Variadic:$output1, AnyTensor:$output2, Optional:$output3); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -22,6 +22,7 @@ #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/SideEffects.h" #include "mlir/TableGen/Trait.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" @@ -2518,67 +2519,57 @@ FmtContext fctx; fctx.withBuilder("odsBuilder"); + fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; - // Preprocess the result types and build all of the types used during - // inferrence. This limits the amount of duplicated work when a type is used - // to infer multiple others. - llvm::DenseMap constraintsTypes; - llvm::DenseMap argumentsTypes; + // Process the type inference graph in topological order, starting from types + // that are always fully-inferred: operands and results with constructible + // types. The type inference graph here will always be a DAG, so this gives + // us the correct order for generating the types. -1 is a placeholder to + // indicate the type for a result has not been generated. + SmallVector constructedIndices(op.getNumResults(), -1); int inferredTypeIdx = 0; - for (int i = 0, e = op.getNumResults(); i != e; ++i) { - auto type = op.getSameTypeAsResult(i).front(); - - // If the type isn't an argument, it refers to a buildable type. 
- if (!type.isArg()) { - auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx); - if (!it.second) + for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { + for (int i = 0, e = op.getNumResults(); i != e; ++i) { + if (constructedIndices[i] >= 0) continue; - - // If we haven't seen this constraint, generate a variable for it. - body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = " - << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n"; - continue; - } - - // Otherwise, this is an argument. - int argIndex = type.getArg(); - auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx); - if (!it.second) - continue; - body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "; - - // If this is an operand, just index into operand list to access the type. - auto arg = op.getArgToOperandOrAttribute(argIndex); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { - body << "operands[" << arg.operandOrAttributeIndex() << "].getType()"; - - // If this is an attribute, index into the attribute dictionary. - } else { - auto *attr = - op.getArg(arg.operandOrAttributeIndex()).get(); - body << "attributes.get(\"" << attr->name - << "\").cast<::mlir::TypedAttr>().getType()"; + const InferredResultType &infer = op.getInferredResultType(i); + std::string typeStr; + body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "; + if (infer.isArg()) { + // If this is an operand, just index into operand list to access the + // type. + auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + + "].getType()") + .str(); + + // If this is an attribute, index into the attribute dictionary. 
+ } else { + auto *attr = + op.getArg(arg.operandOrAttributeIndex()).get(); + typeStr = ("attributes.get(\"" + attr->name + + "\").cast<::mlir::TypedAttr>().getType()") + .str(); + } + } else if (std::optional builder = + op.getResult(infer.getResultIndex()) + .constraint.getBuilderCall()) { + typeStr = tgfmt(*builder, &fctx).str(); + } else if (int index = constructedIndices[infer.getResultIndex()]; + index >= 0) { + typeStr = ("odsInferredType" + Twine(index)).str(); + } else { + continue; + } + body << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n"; + constructedIndices[i] = inferredTypeIdx - 1; } - body << ";\n"; } - - // Perform a second pass that handles assigning the inferred types to the - // results. - for (int i = 0, e = op.getNumResults(); i != e; ++i) { - auto types = op.getSameTypeAsResult(i); - - // Append the inferred type. - auto type = types.front(); - body << " inferredReturnTypes[" << i << "] = odsInferredType" - << (type.isArg() ? argumentsTypes[type.getArg()] - : constraintsTypes[type.getType()]) + for (auto [i, index] : llvm::enumerate(constructedIndices)) + body << " inferredReturnTypes[" << i << "] = odsInferredType" << index << ";\n"; - - if (types.size() == 1) - continue; - // TODO: We could verify equality here, but skipping that for verification. - } body << " return ::mlir::success();"; }