diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -221,11 +221,28 @@ Normally operations have no variadic operands or just one variadic operand. For the latter case, it is easy to deduce which dynamic operands are for the static -variadic operand definition. But if an operation has more than one variadic -operands, it would be impossible to attribute dynamic operands to the -corresponding static variadic operand definitions without further information -from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to -indicate that all variadic operands have the same number of dynamic values. +variadic operand definition. However, if an operation has more than one +variable-length operand (either optional or variadic), it would be impossible +to attribute dynamic operands to the corresponding static operand definitions +without further information from the operation. Therefore, either the +`SameVariadicOperandSize` trait (all variable-length operands have the same +number of dynamic values) or the `AttrSizedOperandSegments` trait (an attribute +records the number of dynamic values for each operand) is needed. + +#### Optional operands + +To declare an optional operand, wrap the `TypeConstraint` for the operand with +`Optional<...>`. + +Normally operations have no optional operands or just one optional operand. For +the latter case, it is easy to deduce which dynamic operands are for the static +operand definition. However, if an operation has more than one variable-length +operand (either optional or variadic), it would be impossible to attribute +dynamic operands to the corresponding static operand definitions without +further information from the operation. Therefore, either the +`SameVariadicOperandSize` trait (all variable-length operands have the same +number of dynamic values) or the `AttrSizedOperandSegments` trait (an attribute +records the number of dynamic values for each operand) is needed. #### Optional attributes @@ -693,7 +710,7 @@ the group.
- Any attribute variable may be used, but only optional attributes can be marked as the anchor. - - Only variadic, i.e. optional, operand arguments can be used. + - Only variadic or optional operand arguments can be used. - The operands to a type directive must be defined within the optional group. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -297,13 +297,17 @@ } // A variadic type constraint. It expands to zero or more of the base type. This -// class is used for supporting variadic operands/results. An op can declare no -// more than one variadic operand/result, and that operand/result must be the -// last one in the operand/result list. +// class is used for supporting variadic operands/results. class Variadic : TypeConstraint { Type baseType = type; } +// An optional type constraint. It expands to either zero or one of the base +// type. This class is used for supporting optional operands/results. +class Optional : TypeConstraint { + Type baseType = type; +} + // A type that can be constructed using MLIR::Builder. // Note that this does not "inherit" from Type because it would require // duplicating Type subclasses for buildable and non-buildable cases to avoid diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -621,6 +621,9 @@ /// Parse a type. virtual ParseResult parseType(Type &result) = 0; + /// Parse an optional type. + virtual OptionalParseResult parseOptionalType(Type &result) = 0; + /// Parse a type of a specific type. 
template ParseResult parseType(TypeT &result) { diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -43,8 +43,14 @@ struct NamedTypeConstraint { // Returns true if this operand/result has constraint to be satisfied. bool hasPredicate() const; + // Returns true if this is an optional type constraint. This is a special case + // of variadic for 0 or 1 type. + bool isOptional() const; // Returns true if this operand/result is variadic. bool isVariadic() const; + // Returns true if this is a variable length type constraint. This is either + // variadic or optional. + bool isVariableLength() const { return isOptional() || isVariadic(); } llvm::StringRef name; TypeConstraint constraint; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -88,7 +88,7 @@ using value_iterator = NamedTypeConstraint *; using value_range = llvm::iterator_range; - // Returns true if this op has variadic operands or results. + // Returns true if this op has variable length operands or results. bool isVariadic() const; // Returns true if default builders should not be generated. @@ -115,8 +115,8 @@ // Returns the `index`-th result's decorators. var_decorator_range getResultDecorators(int index) const; - // Returns the number of variadic results in this operation. - unsigned getNumVariadicResults() const; + // Returns the number of variable length results in this operation. + unsigned getNumVariableLengthResults() const; // Op attribute iterators. using attribute_iterator = const NamedAttribute *; @@ -142,7 +142,7 @@ } // Returns the number of variadic operands in this operation. - unsigned getNumVariadicOperands() const; + unsigned getNumVariableLengthOperands() const; // Returns the total number of arguments. 
int getNumArgs() const { return arguments.size(); } diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -34,9 +34,16 @@ static bool classof(const Constraint *c) { return c->getKind() == CK_Type; } + // Returns true if this is an optional type constraint. + bool isOptional() const; + // Returns true if this is a variadic type constraint. bool isVariadic() const; + // Returns true if this is a variable length type constraint. This is either + // variadic or optional. + bool isVariableLength() const { return isOptional() || isVariadic(); } + // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise. Optional getBuilderCall() const; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -227,6 +227,9 @@ ParseResult parseTypeListNoParens(SmallVectorImpl &elements); ParseResult parseTypeListParens(SmallVectorImpl &elements); + /// Optionally parse a type. + OptionalParseResult parseOptionalType(Type &type); + /// Parse an arbitrary type. Type parseType(); @@ -899,6 +902,31 @@ // Type Parsing //===----------------------------------------------------------------------===// +/// Optionally parse a type. +OptionalParseResult Parser::parseOptionalType(Type &type) { + // There are many different starting tokens for a type, check them here. + switch (getToken().getKind()) { + case Token::l_paren: + case Token::kw_memref: + case Token::kw_tensor: + case Token::kw_complex: + case Token::kw_tuple: + case Token::kw_vector: + case Token::inttype: + case Token::kw_bf16: + case Token::kw_f16: + case Token::kw_f32: + case Token::kw_f64: + case Token::kw_index: + case Token::kw_none: + case Token::exclamation_identifier: + return failure(!(type = parseType())); + + default: + return llvm::None; + } +} + /// Parse an arbitrary type. 
/// /// type ::= function-type @@ -4509,6 +4537,11 @@ return failure(!(result = parser.parseType())); } + /// Parse an optional type. + OptionalParseResult parseOptionalType(Type &result) override { + return parser.parseOptionalType(result); + } + /// Parse an arrow followed by a type list. ParseResult parseArrowTypeList(SmallVectorImpl &result) override { if (parseArrow() || parser.parseFunctionResultTypes(result)) diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -15,6 +15,10 @@ return !constraint.getPredicate().isNull(); } +bool tblgen::NamedTypeConstraint::isOptional() const { + return constraint.isOptional(); +} + bool tblgen::NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -81,10 +81,6 @@ const llvm::Record &tblgen::Operator::getDef() const { return def; } -bool tblgen::Operator::isVariadic() const { - return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0; -} - bool tblgen::Operator::skipDefaultBuilders() const { return def.getValueAsBit("skipDefaultBuilders"); } @@ -119,16 +115,16 @@ return *result->getValueAsListInit("decorators"); } -unsigned tblgen::Operator::getNumVariadicResults() const { - return std::count_if( - results.begin(), results.end(), - [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); +unsigned tblgen::Operator::getNumVariableLengthResults() const { + return llvm::count_if(results, [](const NamedTypeConstraint &c) { + return c.constraint.isVariableLength(); + }); } -unsigned tblgen::Operator::getNumVariadicOperands() const { - return std::count_if( - operands.begin(), operands.end(), - [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); +unsigned tblgen::Operator::getNumVariableLengthOperands() const { 
+ return llvm::count_if(operands, [](const NamedTypeConstraint &c) { + return c.constraint.isVariableLength(); + }); } tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const { diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -255,7 +255,7 @@ auto *operand = op->getArg(*argIndex).get(); // If this operand is variadic, then return a range. Otherwise, return the // value itself. - if (operand->isVariadic()) { + if (operand->isVariableLength()) { auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); return std::string(repl); diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -26,6 +26,10 @@ TypeConstraint::TypeConstraint(const llvm::DefInit *init) : TypeConstraint(init->getDef()) {} +bool TypeConstraint::isOptional() const { + return def->isSubClassOf("Optional"); +} + bool TypeConstraint::isVariadic() const { return def->isSubClassOf("Variadic"); } @@ -34,7 +38,7 @@ // returns None otherwise. Optional TypeConstraint::getBuilderCall() const { const llvm::Record *baseType = def; - if (isVariadic()) + if (isVariableLength()) baseType = baseType->getValueAsDef("baseType"); // Check to see if this type constraint has a builder call. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1179,39 +1179,41 @@ } // Test various mixings of result type formatting. 
-class FormatResultBase : TEST_Op { +class FormatResultBase + : TEST_Op<"format_result_" # suffix # "_op"> { let results = (outs I64:$buildable_res, AnyMemRef:$result); let assemblyFormat = fmt; } -def FormatResultAOp : FormatResultBase<"format_result_a_op", [{ +def FormatResultAOp : FormatResultBase<"a", [{ type($result) attr-dict }]>; -def FormatResultBOp : FormatResultBase<"format_result_b_op", [{ +def FormatResultBOp : FormatResultBase<"b", [{ type(results) attr-dict }]>; -def FormatResultCOp : FormatResultBase<"format_result_c_op", [{ +def FormatResultCOp : FormatResultBase<"c", [{ functional-type($buildable_res, $result) attr-dict }]>; // Test various mixings of operand type formatting. -class FormatOperandBase : TEST_Op { +class FormatOperandBase + : TEST_Op<"format_operand_" # suffix # "_op"> { let arguments = (ins I64:$buildable, AnyMemRef:$operand); let assemblyFormat = fmt; } -def FormatOperandAOp : FormatOperandBase<"format_operand_a_op", [{ +def FormatOperandAOp : FormatOperandBase<"a", [{ operands `:` type(operands) attr-dict }]>; -def FormatOperandBOp : FormatOperandBase<"format_operand_b_op", [{ +def FormatOperandBOp : FormatOperandBase<"b", [{ operands `:` type($operand) attr-dict }]>; -def FormatOperandCOp : FormatOperandBase<"format_operand_c_op", [{ +def FormatOperandCOp : FormatOperandBase<"c", [{ $buildable `,` $operand `:` type(operands) attr-dict }]>; -def FormatOperandDOp : FormatOperandBase<"format_operand_d_op", [{ +def FormatOperandDOp : FormatOperandBase<"d", [{ $buildable `,` $operand `:` type($operand) attr-dict }]>; -def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{ +def FormatOperandEOp : FormatOperandBase<"e", [{ $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict }]>; @@ -1220,6 +1222,25 @@ let assemblyFormat = "$targets attr-dict"; } +// Test various mixings of optional operand and result type formatting. 
+class FormatOptionalOperandResultOpBase + : TEST_Op<"format_optional_operand_result_" # suffix # "_op", + [AttrSizedOperandSegments]> { + let arguments = (ins Optional:$optional, Variadic:$variadic); + let results = (outs Optional:$optional_res); + let assemblyFormat = fmt; +} + +def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{ + `(` $optional `:` type($optional) `)` `:` type($optional_res) + (`[` $variadic^ `]`)? attr-dict +}]>; + +def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{ + (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res) + (`[` $variadic^ `]`)? attr-dict +}]>; + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -112,6 +112,16 @@ // CHECK-LABEL: NS::DOp declarations // CHECK: OpTrait::NOperands<2>::Impl +def NS_EOp : NS_Op<"op_with_optionals", []> { + let arguments = (ins Optional:$a); + let results = (outs Optional:$b); +} + +// CHECK-LABEL: NS::EOp declarations +// CHECK: Value a(); +// CHECK: Value b(); +// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a) + // Check that default builders can be suppressed. 
// --- diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -222,7 +222,7 @@ def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{ ($attr^) attr-dict }]>, Arguments<(ins I64Attr:$attr)>; -// CHECK: error: only variadic operands can be used within an optional group +// CHECK: error: only variable length operands can be used within an optional group def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{ ($arg^) attr-dict }]>, Arguments<(ins I64:$arg)>; @@ -327,6 +327,17 @@ }]> { let successors = (successor AnySuccessor:$successor); } +// CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred +// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format +def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{ + operands attr-dict +}]>, Arguments<(ins Optional:$operand)>; +// CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred +// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format +def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{ + attr-dict +}]>, Results<(outs Optional:$result)>; + // CHECK-NOT: error def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{ $operand type($operand) type($result) attr-dict diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -18,6 +18,10 @@ // CHECK: test.format_buildable_type_op %[[I64]] %ignored = test.format_buildable_type_op %i64 +//===----------------------------------------------------------------------===// +// Format results 
+//===----------------------------------------------------------------------===// + // CHECK: test.format_result_a_op memref<1xf64> %ignored_a:2 = test.format_result_a_op memref<1xf64> @@ -27,6 +31,10 @@ // CHECK: test.format_result_c_op (i64) -> memref<1xf64> %ignored_c:2 = test.format_result_c_op (i64) -> memref<1xf64> +//===----------------------------------------------------------------------===// +// Format operands +//===----------------------------------------------------------------------===// + // CHECK: test.format_operand_a_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> test.format_operand_a_op %i64, %memref : i64, memref<1xf64> @@ -42,6 +50,10 @@ // CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> test.format_operand_e_op %i64, %memref : i64, memref<1xf64> +//===----------------------------------------------------------------------===// +// Format successors +//===----------------------------------------------------------------------===// + "foo.successor_test_region"() ( { ^bb0: // CHECK: test.format_successor_a_op ^bb1 {attr} @@ -57,3 +69,28 @@ }) { arg_names = ["i", "j", "k"] } : () -> () +//===----------------------------------------------------------------------===// +// Format optional operands and results +//===----------------------------------------------------------------------===// + +// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : i64 +test.format_optional_operand_result_a_op(%i64 : i64) : i64 + +// CHECK: test.format_optional_operand_result_a_op( : ) : i64 +test.format_optional_operand_result_a_op( : ) : i64 + +// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : +// CHECK-NOT: i64 +test.format_optional_operand_result_a_op(%i64 : i64) : + +// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : [%[[I64]], %[[I64]]] +test.format_optional_operand_result_a_op(%i64 : i64) : [%i64, %i64] + +// CHECK: test.format_optional_operand_result_b_op(%[[I64]] : i64) : i64 
+test.format_optional_operand_result_b_op(%i64 : i64) : i64 + +// CHECK: test.format_optional_operand_result_b_op : i64 +test.format_optional_operand_result_b_op( : ) : i64 + +// CHECK: test.format_optional_operand_result_b_op : i64 +test.format_optional_operand_result_b_op : i64 diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,8 @@ } // CHECK-LABEL: OpA::verify -// CHECK: for (Value v : getODSOperands(0)) { +// CHECK: auto valueGroup0 = getODSOperands(0); +// CHECK: for (Value v : valueGroup0) { // CHECK: if (!((v.getType().isInteger(32) || v.getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ @@ -90,5 +91,6 @@ } // CHECK-LABEL: OpK::verify -// CHECK: for (Value v : getODSOperands(0)) { +// CHECK: auto valueGroup0 = getODSOperands(0); +// CHECK: for (Value v : valueGroup0) { // CHECK: if (!(((v.getType().isa())) && (((v.getType().cast().getElementType().isF32())) || ((v.getType().cast().getElementType().isSignlessInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -75,7 +75,7 @@ if (numOperands == 0) return false; const auto &operand = op.getOperand(numOperands - 1); - return operand.isVariadic() && operand.name == name; + return operand.isVariableLength() && operand.name == name; } // Check if `result` is a known name of a result of `op`. 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -452,7 +452,7 @@ StringRef rangeSizeCall, StringRef getOperandCallPattern) { const int numOperands = op.getNumOperands(); - const int numVariadicOperands = op.getNumVariadicOperands(); + const int numVariadicOperands = op.getNumVariableLengthOperands(); const int numNormalOperands = numOperands - numVariadicOperands; const auto *sameVariadicSize = @@ -493,9 +493,9 @@ // calculation at run-time. llvm::SmallVector isVariadic; isVariadic.reserve(numOperands); - for (int i = 0; i < numOperands; ++i) { - isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic())); - } + for (int i = 0; i < numOperands; ++i) + isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true" + : "false"); std::string isVariadicList = llvm::join(isVariadic, ", "); m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, @@ -511,11 +511,15 @@ if (operand.name.empty()) continue; - if (operand.isVariadic()) { + if (operand.isOptional()) { + auto &m = opClass.newMethod("Value", operand.name); + m.body() << " auto operands = getODSOperands(" << i << ");\n" + << " return operands.empty() ? 
Value() : *operands.begin();"; + } else if (operand.isVariadic()) { auto &m = opClass.newMethod(rangeType, operand.name); m.body() << " return getODSOperands(" << i << ");"; } else { - auto &m = opClass.newMethod("Value ", operand.name); + auto &m = opClass.newMethod("Value", operand.name); m.body() << " return *getODSOperands(" << i << ").begin();"; } } @@ -534,7 +538,7 @@ void OpEmitter::genNamedResultGetters() { const int numResults = op.getNumResults(); - const int numVariadicResults = op.getNumVariadicResults(); + const int numVariadicResults = op.getNumVariableLengthResults(); const int numNormalResults = numResults - numVariadicResults; // If we have more than one variadic results, we need more complicated logic @@ -573,9 +577,9 @@ } else { llvm::SmallVector isVariadic; isVariadic.reserve(numResults); - for (int i = 0; i < numResults; ++i) { - isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic())); - } + for (int i = 0; i < numResults; ++i) + isVariadic.push_back(op.getResult(i).isVariableLength() ? "true" + : "false"); std::string isVariadicList = llvm::join(isVariadic, ", "); m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, @@ -589,11 +593,15 @@ if (result.name.empty()) continue; - if (result.isVariadic()) { + if (result.isOptional()) { + auto &m = opClass.newMethod("Value", result.name); + m.body() << " auto results = getODSResults(" << i << ");\n" + << " return results.empty() ? 
Value() : *results.begin();"; + } else if (result.isVariadic()) { auto &m = opClass.newMethod("Operation::result_range", result.name); m.body() << " return getODSResults(" << i << ");"; } else { - auto &m = opClass.newMethod("Value ", result.name); + auto &m = opClass.newMethod("Value", result.name); m.body() << " return *getODSResults(" << i << ").begin();"; } } @@ -706,6 +714,8 @@ return; case TypeParamKind::Separate: for (int i = 0, e = op.getNumResults(); i < e; ++i) { + if (op.getResult(i).isOptional()) + body << " if (" << resultNames[i] << ")\n "; body << " " << builderOpState << ".addTypes(" << resultNames[i] << ");\n"; } @@ -713,12 +723,12 @@ case TypeParamKind::Collective: body << " " << "assert(resultTypes.size() " - << (op.getNumVariadicResults() == 0 ? "==" : ">=") << " " - << (op.getNumResults() - op.getNumVariadicResults()) + << (op.getNumVariableLengthResults() == 0 ? "==" : ">=") << " " + << (op.getNumResults() - op.getNumVariableLengthResults()) << "u && \"mismatched number of results\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; return; - }; + } llvm_unreachable("unhandled TypeParamKind"); }; @@ -731,7 +741,7 @@ // Emit separate arg build with collective type, unless there is only one // variadic result, in which case the above would have already generated // the same build method. - if (!(op.getNumResults() == 1 && op.getResult(0).isVariadic())) + if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength())) emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } } @@ -739,7 +749,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { // If this op has a variadic result, we cannot generate this builder because // we don't know how many results to create. - if (op.getNumVariadicResults() != 0) + if (op.getNumVariableLengthResults() != 0) return; int numResults = op.getNumResults(); @@ -887,7 +897,7 @@ // 3. 
one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. - if (op.getNumVariadicResults() == 0) { + if (op.getNumVariableLengthResults() == 0) { if (op.getTrait("OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); genUseOperandAsResultTypeCollectiveParamBuilder(); @@ -899,11 +909,11 @@ void OpEmitter::genCollectiveParamBuilder() { int numResults = op.getNumResults(); - int numVariadicResults = op.getNumVariadicResults(); + int numVariadicResults = op.getNumVariableLengthResults(); int numNonVariadicResults = numResults - numVariadicResults; int numOperands = op.getNumOperands(); - int numVariadicOperands = op.getNumVariadicOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; // Signature std::string params = std::string("Builder *, OperationState &") + @@ -972,7 +982,12 @@ if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); - paramList.append(result.isVariadic() ? ", ArrayRef " : ", Type "); + if (result.isOptional()) + paramList.append(", /*optional*/Type "); + else if (result.isVariadic()) + paramList.append(", ArrayRef "); + else + paramList.append(", Type "); paramList.append(resultName); resultTypeNames.emplace_back(std::move(resultName)); @@ -1018,7 +1033,12 @@ auto argument = op.getArg(i); if (argument.is()) { const auto &operand = op.getOperand(numOperands); - paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value "); + if (operand.isOptional()) + paramList.append(", /*optional*/Value "); + else if (operand.isVariadic()) + paramList.append(", ValueRange "); + else + paramList.append(", Value "); paramList.append(getArgumentName(op, numOperands)); ++numOperands; } else { @@ -1076,8 +1096,10 @@ bool isRawValueAttr) { // Push all operands to the result. 
for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i) - << ");\n"; + std::string argName = getArgumentName(op, i); + if (op.getOperand(i).isOptional()) + body << " if (" << argName << ")\n "; + body << " " << builderOpState << ".addOperands(" << argName << ");\n"; } // If the operation has the operand segment size attribute, add it here. @@ -1086,7 +1108,9 @@ << ".addAttribute(\"operand_segment_sizes\", " "odsBuilder->getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isVariadic()) + if (op.getOperand(i).isOptional()) + body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; + else if (op.getOperand(i).isVariadic()) body << "static_cast(" << getArgumentName(op, i) << ".size())"; else body << "1"; @@ -1160,7 +1184,7 @@ void OpEmitter::genFolderDecls() { bool hasSingleResult = - op.getNumResults() == 1 && op.getNumVariadicResults() == 0; + op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; if (def.getValueAsBit("hasFolder")) { if (hasSingleResult) { @@ -1434,17 +1458,33 @@ body << " unsigned index = 0; (void)index;\n"; for (auto staticValue : llvm::enumerate(values)) { - if (!staticValue.value().hasPredicate()) + bool hasPredicate = staticValue.value().hasPredicate(); + bool isOptional = staticValue.value().isOptional(); + if (!hasPredicate && !isOptional) continue; - - // Emit a loop to check all the dynamic values in the pack. - body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n", + body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), staticValue.index()); - auto constraint = staticValue.value().constraint; + // If the constraint is optional check that the value group has at most 1 + // value. 
+ if (isOptional) { + body << formatv(" if (valueGroup{0}.size() > 1)\n" + " return emitOpError(\"{1} group starting at #\") " + "<< index << \" requires 0 or 1 element, but found \" << " + "valueGroup{0}.size();\n", + staticValue.index(), valueKind); + } + + // Otherwise, if there is no predicate there is nothing left to do. + if (!hasPredicate) + continue; + // Emit a loop to check all the dynamic values in the pack. + body << " for (Value v : valueGroup" << staticValue.index() << ") {\n"; + + auto constraint = staticValue.value().constraint; body << " (void)v;\n" << " if (!(" << tgfmt(constraint.getConditionTemplate(), @@ -1569,7 +1609,7 @@ // Add result size trait. int numResults = op.getNumResults(); - int numVariadicResults = op.getNumVariadicResults(); + int numVariadicResults = op.getNumVariableLengthResults(); addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); // Add successor size trait. @@ -1579,7 +1619,7 @@ // Add variadic size trait and normal op traits. int numOperands = op.getNumOperands(); - int numVariadicOperands = op.getNumVariadicOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); // Add operand size trait. 
if (numVariadicOperands != 0) { diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -395,6 +395,17 @@ if (parser.parseOperandList({0}Operands)) return failure(); )"; +const char *const optionalOperandParserCode = R"( + { + OpAsmParser::OperandType operand; + OptionalParseResult parseResult = parser.parseOptionalOperand(operand); + if (parseResult.hasValue()) { + if (failed(*parseResult)) + return failure(); + {0}Operands.push_back(operand); + } + } +)"; const char *const operandParserCode = R"( if (parser.parseOperand({0}RawOperands[0])) return failure(); @@ -407,6 +418,17 @@ if (parser.parseTypeList({0}Types)) return failure(); )"; +const char *const optionalTypeParserCode = R"( + { + Type optionalType; + OptionalParseResult parseResult = parser.parseOptionalType(optionalType); + if (parseResult.hasValue()) { + if (failed(*parseResult)) + return failure(); + {0}Types.push_back(optionalType); + } + } +)"; const char *const typeParserCode = R"( if (parser.parseType({0}RawTypes[0])) return failure(); @@ -456,18 +478,40 @@ return failure(); )"; +namespace { +/// The type of length for a given parse argument. +enum class ArgumentLengthKind { + /// The argument is variadic, and may contain 0->N elements. + Variadic, + /// The argument is optional, and may contain 0 or 1 elements. + Optional, + /// The argument is a single element, i.e. always represents 1 element. + Single +}; +} // end anonymous namespace + +/// Get the length kind for the given constraint. +static ArgumentLengthKind +getArgumentLengthKind(const NamedTypeConstraint *var) { + if (var->isOptional()) + return ArgumentLengthKind::Optional; + if (var->isVariadic()) + return ArgumentLengthKind::Variadic; + return ArgumentLengthKind::Single; +} + /// Get the name used for the type list for the given type directive operand. 
-/// 'isVariadic' is set to true if the operand has variadic types. -static StringRef getTypeListName(Element *arg, bool &isVariadic) { +/// 'lengthKind' to the corresponding kind for the given argument. +static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) { if (auto *operand = dyn_cast(arg)) { - isVariadic = operand->getVar()->isVariadic(); + lengthKind = getArgumentLengthKind(operand->getVar()); return operand->getVar()->name; } if (auto *result = dyn_cast(arg)) { - isVariadic = result->getVar()->isVariadic(); + lengthKind = getArgumentLengthKind(result->getVar()); return result->getVar()->name; } - isVariadic = true; + lengthKind = ArgumentLengthKind::Variadic; if (isa(arg)) return "allOperand"; if (isa(arg)) @@ -502,7 +546,7 @@ genElementParserStorage(&childElement, body); } else if (auto *operand = dyn_cast(element)) { StringRef name = operand->getVar()->name; - if (operand->getVar()->isVariadic()) { + if (operand->getVar()->isVariableLength()) { body << " SmallVector " << name << "Operands;\n"; } else { @@ -515,15 +559,15 @@ " (void){0}OperandsLoc;\n", name); } else if (auto *dir = dyn_cast(element)) { - bool variadic = false; - StringRef name = getTypeListName(dir->getOperand(), variadic); - if (variadic) + ArgumentLengthKind lengthKind; + StringRef name = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind != ArgumentLengthKind::Single) body << " SmallVector " << name << "Types;\n"; else body << llvm::formatv(" Type {0}RawTypes[1];\n", name) << llvm::formatv(" ArrayRef {0}Types({0}RawTypes);\n", name); } else if (auto *dir = dyn_cast(element)) { - bool ignored = false; + ArgumentLengthKind ignored; body << " ArrayRef " << getTypeListName(dir->getInputs(), ignored) << "Types;\n"; body << " ArrayRef " << getTypeListName(dir->getResults(), ignored) @@ -592,9 +636,14 @@ body << formatv(attrParserCode, var->attr.getStorageType(), var->name, attrTypeStr); } else if (auto *operand = dyn_cast(element)) { - bool isVariadic = 
operand->getVar()->isVariadic(); - body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode, - operand->getVar()->name); + ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); + StringRef name = operand->getVar()->name; + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicOperandParserCode, name); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalOperandParserCode, name); + else + body << formatv(operandParserCode, name); } else if (auto *successor = dyn_cast(element)) { bool isVariadic = successor->getVar()->isVariadic(); body << formatv(isVariadic ? successorListParserCode : successorParserCode, @@ -614,12 +663,16 @@ } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full"); } else if (auto *dir = dyn_cast(element)) { - bool isVariadic = false; - StringRef listName = getTypeListName(dir->getOperand(), isVariadic); - body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode, - listName); + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicTypeParserCode, listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalTypeParserCode, listName); + else + body << formatv(typeParserCode, listName); } else if (auto *dir = dyn_cast(element)) { - bool ignored = false; + ArgumentLengthKind ignored; body << formatv(functionalTypeParserCode, getTypeListName(dir->getInputs(), ignored), getTypeListName(dir->getResults(), ignored)); @@ -817,7 +870,7 @@ << "builder.getI32VectorAttr({"; auto interleaveFn = [&](const NamedTypeConstraint &operand) { // If the operand is variadic emit the parsed size. 
- if (operand.isVariadic()) + if (operand.isVariableLength()) body << "static_cast(" << operand.name << "Operands.size())"; else body << "1"; @@ -885,6 +938,10 @@ auto *var = operand ? operand->getVar() : cast(arg)->getVar(); if (var->isVariadic()) return body << var->name << "().getTypes()"; + if (var->isOptional()) + return body << llvm::formatv( + "({0}() ? ArrayRef({0}().getType()) : ArrayRef())", + var->name); return body << "ArrayRef(" << var->name << "().getType())"; } @@ -900,11 +957,16 @@ if (OptionalElement *optional = dyn_cast(element)) { // Emit the check for the presence of the anchor element. Element *anchor = optional->getAnchor(); - if (AttributeVariable *attrVar = dyn_cast(anchor)) - body << " if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n"; - else - body << " if (!" << cast(anchor)->getVar()->name - << "().empty()) {\n"; + if (auto *operand = dyn_cast(anchor)) { + const NamedTypeConstraint *var = operand->getVar(); + if (var->isOptional()) + body << " if (" << var->name << "()) {\n"; + else if (var->isVariadic()) + body << " if (!" << var->name << "().empty()) {\n"; + } else { + body << " if (getAttr(\"" + << cast(anchor)->getVar()->name << "\")) {\n"; + } // Emit each of the elements. for (Element &childElement : optional->getElements()) @@ -945,7 +1007,12 @@ else body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { - body << " p << " << operand->getVar()->name << "();\n"; + if (operand->getVar()->isOptional()) { + body << " if (Value value = " << operand->getVar()->name << "())\n" + << " p << value;\n"; + } else { + body << " p << " << operand->getVar()->name << "();\n"; + } } else if (auto *successor = dyn_cast(element)) { const NamedSuccessor *var = successor->getVar(); if (var->isVariadic()) @@ -1521,14 +1588,12 @@ // Similarly to results, allow a custom builder for resolving the type if // we aren't using the 'operands' directive. 
Optional builder = operand.constraint.getBuilderCall(); - if (!builder || (hasAllOperands && operand.isVariadic())) { + if (!builder || (hasAllOperands && operand.isVariableLength())) { return emitErrorAndNote( loc, "type of operand #" + Twine(i) + ", named '" + operand.name + - "', is not buildable and a buildable " + - "type cannot be inferred", - "suggest adding a type constraint " - "to the operation or adding a " + "', is not buildable and a buildable type cannot be inferred", + "suggest adding a type constraint to the operation or adding a " "'type($" + operand.name + ")' directive to the " + "custom assembly format"); } @@ -1559,18 +1624,16 @@ continue; } - // If the result is not variadic, allow for the case where the type has a - // builder that we can use. + // If the result is not variable length, allow for the case where the type + // has a builder that we can use. NamedTypeConstraint &result = op.getResult(i); Optional builder = result.constraint.getBuilderCall(); - if (!builder || result.constraint.isVariadic()) { + if (!builder || result.isVariableLength()) { return emitErrorAndNote( loc, "type of result #" + Twine(i) + ", named '" + result.name + - "', is not buildable and a buildable " + - "type cannot be inferred", - "suggest adding a type constraint " - "to the operation or adding a " + "', is not buildable and a buildable type cannot be inferred", + "suggest adding a type constraint to the operation or adding a " "'type($" + result.name + ")' directive to the " + "custom assembly format"); } @@ -1842,9 +1905,9 @@ // Only optional-like(i.e. variadic) operands can be within an optional // group. 
.Case([&](OperandVariable *ele) { - if (!ele->getVar()->isVariadic()) - return emitError(childLoc, "only variadic operands can be used within" - " an optional group"); + if (!ele->getVar()->isVariableLength()) + return emitError(childLoc, "only variable length operands can be " + "used within an optional group"); seenVariables.insert(ele->getVar()); return success(); }) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -243,7 +243,7 @@ // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { if (auto *operand = opArg.dyn_cast()) { - if (operand->isVariadic()) { + if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " "variadic operand #{1} unsupported now", op.getOperationName(), i); @@ -296,7 +296,7 @@ // of op definition. Constraint constraint = matcher.getAsConstraint(); if (operand->constraint != constraint) { - if (operand->isVariadic()) { + if (operand->isVariableLength()) { auto error = formatv( "further constrain op {0}'s variadic operand #{1} unsupported now", op.getOperationName(), argIndex); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -807,11 +807,11 @@ for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); if (auto valueArg = argument.dyn_cast()) { - if (valueArg->isVariadic()) { + if (valueArg->isVariableLength()) { if (i != e - 1) { - PrintFatalError(loc, - "SPIR-V ops can have Variadic<..> argument only if " - "it's the last argument"); + PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or " + "Optional<...> arguments only if " + "it's the last argument"); } os << tabs << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words); @@ -829,7 +829,7 @@ 
words, wordIndex); os << tabs << " }\n"; os << tabs << formatv(" {0}.push_back(arg);\n", operands); - if (!valueArg->isVariadic()) { + if (!valueArg->isVariableLength()) { os << tabs << formatv(" {0}++;\n", wordIndex); } operandNum++;