diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -31,6 +31,7 @@ vector operations, including a scalable vector type and intrinsics for some Arm SVE instructions. }]; + let useDefaultTypePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -66,20 +67,6 @@ "Type":$elementType ); - let printer = [{ - $_printer << "<"; - for (int64_t dim : getShape()) - $_printer << dim << 'x'; - $_printer << getElementType() << '>'; - }]; - - let parser = [{ - VectorType vector; - if ($_parser.parseType(vector)) - return Type(); - return get($_ctxt, vector.getShape(), vector.getElementType()); - }]; - let extraClassDeclaration = [{ bool hasStaticShape() const { return llvm::none_of(getShape(), ShapedType::isDynamic); diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -64,19 +64,19 @@ AttributeT>> { static FailureOr parse(AsmParser &parser) { AttributeT value; - if (parser.parseAttribute(value)) + if (parser.parseCustomAttributeWithFallback(value)) return failure(); return value; } }; /// Parse a type.
template struct FieldParser< TypeT, std::enable_if_t::value, TypeT>> { static FailureOr parse(AsmParser &parser) { TypeT value; - if (parser.parseType(value)) + if (parser.parseCustomTypeWithFallback(value)) return failure(); return value; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2984,6 +2984,9 @@ string baseCppClass = "::mlir::Type"> : DialectType, /*descr*/"", name # "Type">, AttrOrTypeDef<"Type", name, traits, baseCppClass> { + // Make it possible to use such type as parameters for other types. + string cppType = dialect.cppNamespace # "::" # cppClassName; + // A constant builder provided when the type has no parameters. let builderCall = !if(!empty(parameters), "$_builder.getType<" # dialect.cppNamespace # diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -50,6 +50,36 @@ virtual void printType(Type type); virtual void printAttribute(Attribute attr); + /// Trait to check if `AttrType` provides a `print` method. + template + using has_print_method = + decltype(std::declval().print(std::declval())); + template + using detect_has_print_method = + llvm::is_detected; + + /// Print the provided attribute in the context of an operation custom + /// printer/parser: this will invoke directly the print method on the + /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. + template ::value> + *sfinae = nullptr> + void printStrippedAttrOrType(AttrOrType attrOrType) { + if (succeeded(printAlias(attrOrType))) + return; + attrOrType.print(*this); + } + + /// SFINAE for printing the provided attribute in the context of an operation + /// custom printer in the case where the attribute does not define a print + /// method. 
+ template ::value> + *sfinae = nullptr> + void printStrippedAttrOrType(AttrOrType attrOrType) { + *this << attrOrType; + } + /// Print the given attribute without its type. The corresponding parser must /// provide a valid type for the attribute. virtual void printAttributeWithoutType(Attribute attr); @@ -102,6 +132,14 @@ AsmPrinter(const AsmPrinter &) = delete; void operator=(const AsmPrinter &) = delete; + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + virtual LogicalResult printAlias(Attribute attr); + + /// Print the alias for the given type, return failure if no alias could + /// be printed. + virtual LogicalResult printAlias(Type type); + /// The internal implementation of the printer. Impl *impl; }; @@ -608,6 +646,13 @@ /// Parse an arbitrary attribute of a given type and return it in result. virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; + /// Parse a custom attribute with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + virtual ParseResult parseCustomAttributeWithFallback( + Attribute &result, Type type, + function_ref + parseAttribute) = 0; + /// Parse an attribute of a specific kind and type. template ParseResult parseAttribute(AttrType &result, Type type = {}) { @@ -639,9 +684,9 @@ return parseAttribute(result, Type(), attrName, attrs); } - /// Parse an arbitrary attribute of a given type and return it in result. This - /// also adds the attribute to the specified attribute list with the specified - /// name. + /// Parse an arbitrary attribute of a given type and populate it in `result`. + /// This also adds the attribute to the specified attribute list with the + /// specified name. template ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, NamedAttrList &attrs) { @@ -661,6 +706,82 @@ return success(); } + /// Trait to check if `AttrType` provides a `parse` method. 
+ template + using has_parse_method = decltype(AttrType::parse(std::declval(), + std::declval())); + template + using detect_has_parse_method = llvm::is_detected; + + /// Parse a custom attribute of a given type unless the next token is `#`, in + /// which case the generic parser is invoked. The parsed attribute is + /// populated in `result` and also added to the specified attribute list with + /// the specified name. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result, Type type, + StringRef attrName, NamedAttrList &attrs) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. + Attribute attr; + if (parseCustomAttributeWithFallback( + attr, type, [&](Attribute &result, Type type) -> ParseResult { + result = AttrType::parse(*this, type); + if (!result) + return failure(); + return success(); + })) + return failure(); + + // Check for the right kind of attribute. + result = attr.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of attribute specified"); + + attrs.append(attrName, result); + return success(); + } + + /// SFINAE parsing method for Attribute that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result, Type type, + StringRef attrName, NamedAttrList &attrs) { + return parseAttribute(result, type, attrName, attrs); + } + + /// Parse a custom attribute of a given type unless the next token is `#`, in + /// which case the generic parser is invoked. The parsed attribute is + /// populated in `result`. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. 
+ Attribute attr; + if (parseCustomAttributeWithFallback( + attr, {}, [&](Attribute &result, Type type) -> ParseResult { + result = AttrType::parse(*this, type); + return success(!!result); + })) + return failure(); + + // Check for the right kind of attribute. + result = attr.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of attribute specified"); + return success(); + } + + /// SFINAE parsing method for Attribute that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result) { + return parseAttribute(result); + } + /// Parse an arbitrary optional attribute of a given type and return it in /// result. virtual OptionalParseResult parseOptionalAttribute(Attribute &result, @@ -740,6 +861,11 @@ /// Parse a type. virtual ParseResult parseType(Type &result) = 0; + /// Parse a custom type with the provided callback, unless the next + /// token is `!`, in which case the generic parser is invoked. + virtual ParseResult parseCustomTypeWithFallback( + Type &result, function_ref parseType) = 0; + /// Parse an optional type. virtual OptionalParseResult parseOptionalType(Type &result) = 0; @@ -753,7 +879,7 @@ if (parseType(type)) return failure(); - // Check for the right kind of attribute. + // Check for the right kind of type. result = type.dyn_cast(); if (!result) return emitError(loc, "invalid kind of type specified"); @@ -761,6 +887,44 @@ return success(); } + /// Trait to check if `TypeT` provides a `parse` method. + template + using type_has_parse_method = + decltype(TypeT::parse(std::declval())); + template + using detect_type_has_parse_method = + llvm::is_detected; + + /// Parse a custom type of kind `TypeT` unless the next token is `!`, in + /// which case the generic parser is invoked. The parsed type is + /// populated in `result`.
+ template + std::enable_if_t::value, ParseResult> + parseCustomTypeWithFallback(TypeT &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of Type. + Type type; + if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { + result = TypeT::parse(*this); + return success(!!result); + })) + return failure(); + + // Check for the right kind of Type. + result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of Type specified"); + return success(); + } + + /// SFINAE parsing method for Type that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomTypeWithFallback(TypeT &result) { + return parseType(result); + } + /// Parse a type list. ParseResult parseTypeList(SmallVectorImpl &result) { do { @@ -792,7 +956,7 @@ if (parseColonType(type)) return failure(); - // Check for the right kind of attribute. + // Check for the right kind of type. result = type.dyn_cast(); if (!result) return emitError(loc, "invalid kind of type specified"); diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -53,21 +53,21 @@ // ScalableVectorType //===----------------------------------------------------------------------===// -Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - { - Type genType; - auto parseResult = generatedTypeParser(parser, "vector", genType); - if (parseResult.hasValue()) - return genType; - } - parser.emitError(typeLoc, "unknown type in ArmSVE dialect"); - return Type(); +void ScalableVectorType::print(AsmPrinter &printer) const { + printer << "<"; + for (int64_t dim : getShape()) + printer << dim << 'x'; + printer << getElementType() << '>'; } -void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { - if 
(failed(generatedTypePrinter(type, os))) - llvm_unreachable("unexpected 'arm_sve' type kind"); +Type ScalableVectorType::parse(AsmParser &parser) { + SmallVector dims; + Type eltType; + if (parser.parseLess() || + parser.parseDimensionList(dims, /*allowDynamic=*/false) || + parser.parseType(eltType) || parser.parseGreater()) + return {}; + return ScalableVectorType::get(eltType.getContext(), dims, eltType); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -170,7 +170,7 @@ }; void CombiningKindAttr::print(AsmPrinter &printer) const { - printer << "kind<"; + printer << "<"; auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) { return bitEnumContains(this->getKind(), kind); }); @@ -215,10 +215,12 @@ void VectorDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { - if (auto ck = attr.dyn_cast()) + if (auto ck = attr.dyn_cast()) { + os << "kind"; ck.print(os); - else - llvm_unreachable("Unknown attribute type"); + return; + } + llvm_unreachable("Unknown attribute type"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1175,7 +1175,7 @@ /// Ex: /// ``` /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> -/// %1 = vector.multi_reduction #vector.kind, %0 [1] +/// %1 = vector.multi_reduction add, %0 [1] /// : vector<8x32x16xf32> to vector<8x16xf32> /// ``` /// Gets converted to: @@ -1185,7 +1185,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, 
%arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct MultiReduceToContract @@ -1234,7 +1234,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, %arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` /// Gets converted to: @@ -1244,7 +1244,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %arg0, %arg1, %cst_f0 +/// kind = add} %arg0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct CombineContractTranspose @@ -1291,7 +1291,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, %arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` /// Gets converted to: @@ -1301,7 +1301,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %arg0, %arg1, %cst_f0 +/// kind = add} %arg0, %arg1, %cst_f0 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct CombineContractBroadcast diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -474,6 +474,14 @@ void printAttributeWithoutType(Attribute attr) override { printAttribute(attr); } + LogicalResult printAlias(Attribute attr) override { + initializer.visit(attr); + return success(); + } + LogicalResult printAlias(Type type) override { + 
initializer.visit(type); + return success(); + } /// Print the given set of attributes with names not included within /// 'elidedAttrs'. @@ -1252,8 +1260,16 @@ void printAttribute(Attribute attr, AttrTypeElision typeElision = AttrTypeElision::Never); + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + LogicalResult printAlias(Attribute attr); + void printType(Type type); + /// Print the alias for the given type, return failure if no alias could + /// be printed. + LogicalResult printAlias(Type type); + /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -1594,6 +1610,14 @@ os << R"(opaque<"_", "0xDEADBEEF">)"; } +LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { + return success(state && succeeded(state->getAliasState().getAlias(attr, os))); +} + +LogicalResult AsmPrinter::Impl::printAlias(Type type) { + return success(state && succeeded(state->getAliasState().getAlias(type, os))); +} + void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -1602,7 +1626,7 @@ } // Try to print an alias for this attribute. 
- if (state && succeeded(state->getAliasState().getAlias(attr, os))) + if (succeeded(printAlias(attr))) return; if (!isa(attr.getDialect())) @@ -2104,6 +2128,16 @@ impl->printAttribute(attr); } +LogicalResult AsmPrinter::printAlias(Attribute attr) { + assert(impl && "expected AsmPrinter::printAlias to be overriden"); + return impl->printAlias(attr); +} + +LogicalResult AsmPrinter::printAlias(Type type) { + assert(impl && "expected AsmPrinter::printAlias to be overriden"); + return impl->printAlias(type); +} + void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" @@ -374,6 +375,7 @@ //===----------------------------------------------------------------------===// // BoolAttr +//===----------------------------------------------------------------------===// bool BoolAttr::getValue() const { auto *storage = reinterpret_cast(impl); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/BitVector.h" @@ -633,7 +634,7 @@ return true; // Allow custom dialect attributes. 
- if (!::mlir::isa(memorySpace.getDialect())) + if (!isa(memorySpace.getDialect())) return true; return false; diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -343,6 +343,29 @@ return success(static_cast(result)); } + /// Parse a custom attribute with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + ParseResult parseCustomAttributeWithFallback( + Attribute &result, Type type, + function_ref parseAttribute) + override { + if (parser.getToken().isNot(Token::hash_identifier)) + return parseAttribute(result, type); + result = parser.parseAttribute(type); + return success(static_cast(result)); + } + + /// Parse a custom type with the provided callback, unless the next + /// token is `!`, in which case the generic parser is invoked. + ParseResult parseCustomTypeWithFallback( + Type &result, + function_ref parseType) override { + if (parser.getToken().isNot(Token::exclamation_identifier)) + return parseType(result); + result = parser.parseType(); + return success(static_cast(result)); + } + OptionalParseResult parseOptionalAttribute(Attribute &result, Type type) override { return parser.parseOptionalAttribute(result, type); diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -3,7 +3,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 + // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32 %0 = arm_sve.sdot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -12,7 +12,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c:
!arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 + // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3 %0 = arm_sve.smmla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -21,7 +21,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 + // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32 %0 = arm_sve.udot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -30,7 +30,7 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 + // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3 %0 = arm_sve.ummla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir @@ -16,7 +16,7 @@ %0 = arith.addf %arg0, %arg0 : f32 // CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value %1 = async.runtime.create: !async.value -// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value +// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : async.runtime.store %0, %1: !async.value // CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value async.runtime.set_available %1: !async.value @@ -32,9 +32,9 @@ // CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: ^[[BRANCH_OK]]: -// CHECK: %[[LOADED:.*]] = 
async.runtime.load %[[VAL_STORAGE]] : !async.value +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : // CHECK: %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32 -// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] @@ -84,8 +84,8 @@ // CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]] // CHECK: ^[[BRANCH_VALUE_OK]]: -// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value -// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : +// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] @@ -133,7 +133,7 @@ // CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]] // CHECK: ^[[BRANCH_VALUE_OK_1]]: -// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value +// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : // CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value) // CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] // CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]] @@ -150,8 +150,8 @@ // CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]] // CHECK: ^[[BRANCH_VALUE_OK_2]]: -// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value -// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: %[[LOADED_2:.*]] = async.runtime.load 
%[[RETURNED_TO_CALLER_2]]#1 : +// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -245,7 +245,7 @@ } // CHECK: async.runtime.await %[[RET]]#1 : !async.value - // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value + // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : %0 = async.await %result : !async.value // CHECK: return %[[VALUE]] @@ -323,7 +323,7 @@ // // Load from the async.value argument after error checking. // CHECK: ^[[CONTINUATION:.*]]: -// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value) { - // CHECK: async.runtime.store %arg0, %arg1 : !async.value - async.runtime.store %arg0, %arg1 : !async.value + // CHECK: async.runtime.store %arg0, %arg1 : + async.runtime.store %arg0, %arg1 : return } // CHECK-LABEL: @load func @load(%arg0: !async.value) -> f32 { - // CHECK: %0 = async.runtime.load %arg0 : !async.value + // CHECK: %0 = async.runtime.load %arg0 : // CHECK: return %0 : f32 - %0 = async.runtime.load %arg0 : !async.value + %0 = async.runtime.load %arg0 : return %0 : f32 } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -6,7 +6,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [0] : vector<1584xf32> to f32 +// CHECK: vector.multi_reduction , %{{.*}} [0] : vector<1584xf32> to f32 // CHECK: arith.addf %{{.*}}, %{{.*}} : f32 
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>) outs(%C: memref) @@ -19,7 +19,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32> linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) outs(%C: memref<1584xf32>) @@ -31,7 +31,7 @@ // CHECK-LABEL: contraction_matmul func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32> linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) outs(%C: memref<1584x1584xf32>) @@ -43,7 +43,7 @@ // CHECK-LABEL: contraction_batch_matmul func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> linalg.batch_matmul ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) @@ -71,7 +71,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} 
: memref<8x32xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> linalg.generic #matmul_trait @@ -105,7 +105,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32> linalg.generic #matmul_transpose_out_trait @@ -139,7 +139,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32> // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> @@ -160,7 +160,7 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: vector.multi_reduction #vector.kind, 
%{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32> linalg.matmul ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) @@ -523,7 +523,7 @@ // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later // convert it to a 2D contract. // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32> // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32> // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) @@ -744,7 +744,7 @@ // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32> // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> // CHECK: math.exp {{.*}} : vector<4x16x8xf32> - // CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> + // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> // CHECK: addf {{.*}} : vector<4x16xf32> // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32> // CHECK: return {{.*}} : tensor<4x16xf32> @@ -779,7 +779,7 @@ // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: addf {{.*}} : vector<2x3x4x5xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> + // CHECK: vector.multi_reduction , {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> // CHECK: addf {{.*}} : vector<2x5xf32> // CHECK: 
vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32> // CHECK: return {{.*}} : tensor<5x2xf32> @@ -808,7 +808,7 @@ // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 @@ -833,7 +833,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 @@ -857,7 +857,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: mulf {{.*}} : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 @@ -881,7 +881,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // 
CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -904,7 +904,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant true %init = linalg.init_tensor [4] : tensor<4xi1> @@ -927,7 +927,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -979,7 +979,7 @@ // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: math.exp {{.*}} : vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}} : vector<4x4xf32> to vector<4xf32> // CHECK: addf {{.*}} : vector<4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> %c0 = arith.constant 0.0 : f32 @@ 
-1019,7 +1019,7 @@ // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector - // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[r]] [0] + // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]] [0] // CHECK-SAME: : vector<32xf32> to f32 // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32 // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1027,7 +1027,7 @@ // CHECK-LABEL: func @vector_multi_reduction_single_parallel( // CHECK-SAME: %[[v:.*]]: vector<2xf32> func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [] : vector<2xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [] : vector<2xf32> to vector<2xf32> // CHECK: return %[[v]] : vector<2xf32> return %0 : vector<2xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -3,7 +3,7 @@ // ----- func @broadcast_to_scalar(%arg0: f32) -> f32 { - // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}} + // expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}} %0 = vector.broadcast %arg0 : f32 to f32 } @@ -1008,7 +1008,7 @@ // ----- func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) { - // expected-error@+1 {{must be vector of any type values}} + // expected-error@+1 {{'vector.bitcast' invalid kind of type specified}} %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32 } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- 
a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -676,9 +676,9 @@ // CHECK-LABEL: @multi_reduction func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 { - %1 = vector.multi_reduction #vector.kind, %0 [1, 3] : + %1 = vector.multi_reduction , %0 [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32> - %2 = vector.multi_reduction #vector.kind, %1 [0, 1] : + %2 = vector.multi_reduction , %1 [0, 1] : vector<4x16xf32> to f32 return %2 : f32 } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 { - %0 = vector.multi_reduction #vector.kind, %arg0 [0, 1] : vector<2x4xf32> to f32 + %0 = vector.multi_reduction , %arg0 [0, 1] : vector<2x4xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @vector_multi_reduction_to_scalar @@ -30,7 +30,7 @@ // CHECK: return %[[RES]] func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> + %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } // CHECK-LABEL: func @vector_reduction_inner @@ -66,7 +66,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> { - %0 = 
vector.multi_reduction #vector.kind, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> + %0 = vector.multi_reduction , %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> return %0 : vector<2x5xf32> } @@ -78,7 +78,7 @@ // CHECK: return %[[RESULT]] func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> + %0 = vector.multi_reduction , %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> return %0 : vector<2x4xf32> } // CHECK-LABEL: func @vector_multi_reduction_ordering diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -35,7 +35,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -52,7 +52,7 @@ // CHECK: return 
%[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -69,7 +69,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xi32> func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -86,7 +86,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xi32> func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -104,7 +104,7 @@ func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> + %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -12,7 +12,7 @@ func @multidimreduction_contract( %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> { %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> - %1 = vector.multi_reduction #vector.kind, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32> + %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xf32> to vector<8x16xf32> return %1 : vector<8x16xf32> } @@ -30,7 +30,7 @@ func @multidimreduction_contract_int( %arg0: 
vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> { %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32> - %1 = vector.multi_reduction #vector.kind, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32> + %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xi32> to vector<8x16xi32> return %1 : vector<8x16xi32> } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -14,7 +14,7 @@ #define TEST_ATTRDEFS // To get the test dialect definition. -include "TestOps.td" +include "TestDialect.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/SubElementInterfaces.td" @@ -121,6 +121,29 @@ ); } +// A more complex parameterized attribute with multiple levels of nesting. +def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> { + let mnemonic = "cmpnd_nested_inner"; + // List of attribute parameters. + let parameters = ( + ins + "int":$some_int, + CompoundAttrA:$cmpdA + ); + let assemblyFormat = "`<` $some_int $cmpdA `>`"; +} + +def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> { + let mnemonic = "cmpnd_nested_outer"; + + // List of attribute parameters. + let parameters = ( + ins + CompoundNestedInner:$inner + ); + let assemblyFormat = "`<` `i` $inner `>`"; +} + def TestParamOne : AttrParameter<"int64_t", ""> {} def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -0,0 +1,46 @@ +//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_DIALECT +#define TEST_DIALECT + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "::test"; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let hasCanonicalizer = 1; + let hasConstantMaterializer = 1; + let hasOperationAttrVerify = 1; + let hasRegionArgAttrVerify = 1; + let hasRegionResultAttrVerify = 1; + let hasOperationInterfaceFallback = 1; + let hasNonDefaultDestructor = 1; + let useDefaultAttributePrinterParser = 1; + let dependentDialects = ["::mlir::DLTIDialect"]; + + let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + + // Provides a custom printing/parsing for some operations. + ::llvm::Optional + getParseOperationHook(::llvm::StringRef opName) const override; + ::llvm::unique_function + getOperationPrinter(::mlir::Operation *op) const override; + + private: + // Storage for a custom fallback interface. 
+ void *fallbackEffectOpInterfaces; + + }]; +} + +#endif // TEST_DIALECT diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -9,6 +9,7 @@ #ifndef TEST_OPS #define TEST_OPS +include "TestDialect.td" include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" @@ -23,40 +24,11 @@ include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "TestInterfaces.td" -def Test_Dialect : Dialect { - let name = "test"; - let cppNamespace = "::test"; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; - let hasCanonicalizer = 1; - let hasConstantMaterializer = 1; - let hasOperationAttrVerify = 1; - let hasRegionArgAttrVerify = 1; - let hasRegionResultAttrVerify = 1; - let hasOperationInterfaceFallback = 1; - let hasNonDefaultDestructor = 1; - let useDefaultAttributePrinterParser = 1; - let dependentDialects = ["::mlir::DLTIDialect"]; - - let extraClassDeclaration = [{ - void registerAttributes(); - void registerTypes(); - - // Provides a custom printing/parsing for some operations. - ::llvm::Optional - getParseOperationHook(::llvm::StringRef opName) const override; - ::llvm::unique_function - getOperationPrinter(::mlir::Operation *op) const override; - - private: - // Storage for a custom fallback interface. - void *fallbackEffectOpInterfaces; - - }]; -} // Include the attribute definitions. include "TestAttrDefs.td" +// Include the type definitions. 
+include "TestTypeDefs.td" class TEST_Op traits = []> : @@ -1933,6 +1905,16 @@ let assemblyFormat = "$nested attr-dict-with-keyword"; } +def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; +} + +def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { + let arguments = (ins CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; +} + //===----------------------------------------------------------------------===// // Custom Directives diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -14,8 +14,9 @@ #define TEST_TYPEDEFS // To get the test dialect def. -include "TestOps.td" +include "TestDialect.td" include "TestAttrDefs.td" +include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" @@ -49,6 +50,29 @@ }]; } +// A more complex and nested parameterized type. +def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> { + let mnemonic = "cmpnd_inner"; + // List of type parameters. + let parameters = ( + ins + "int":$some_int, + CompoundTypeA:$cmpdA + ); + let assemblyFormat = "`<` $some_int $cmpdA `>`"; +} + +def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> { + let mnemonic = "cmpnd_nested_outer"; + + // List of type parameters. + let parameters = ( + ins + CompoundNestedInnerType:$inner + ); + let assemblyFormat = "`<` `i` $inner `>`"; +} + // An example of how one could implement a standard integer. 
def IntegerType : Test_Type<"TestInteger"> { let mnemonic = "int"; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -25,6 +25,7 @@ #include "mlir/Interfaces/DataLayoutInterfaces.h" namespace test { +class TestAttrWithFormatAttr; /// FieldInfo represents a field in the StructType data type. It is used as a /// parameter in TestTypeDefs.td. @@ -63,13 +64,13 @@ return test::CustomParam{value.getValue()}; } }; -} // end namespace mlir - inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer, - const test::CustomParam &param) { + test::CustomParam param) { return printer << param.value; } +} // end namespace mlir + #include "TestTypeInterfaces.h.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -61,7 +61,7 @@ // ATTR: printer << ' ' << "hello"; // ATTR: printer << ' ' << "="; // ATTR: printer << ' '; -// ATTR: printer << getValue(); +// ATTR: printer.printStrippedAttrOrType(getValue()); // ATTR: printer << ","; // ATTR: printer << ' '; // ATTR: ::printAttrParamA(printer, getComplex()); @@ -154,10 +154,10 @@ // ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const { // ATTR: printer << ' '; -// ATTR: printer << getV0(); +// ATTR: printer.printStrippedAttrOrType(getV0()); // ATTR: printer << ","; // ATTR: printer << ' '; -// ATTR: printer << getV1(); +// ATTR: printer.printStrippedAttrOrType(getV1()); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -213,7 +213,7 @@ // TYPE: printer << ' ' << "bob"; // TYPE: printer << ' ' << "bar"; // TYPE: printer << ' '; -// TYPE: printer << getValue(); +// TYPE: printer.printStrippedAttrOrType(getValue()); // TYPE: printer << ' ' << "complex"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; @@
-361,21 +361,21 @@ // TYPE: printer << "v0"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV0(); +// TYPE: printer.printStrippedAttrOrType(getV0()); // TYPE: printer << ","; // TYPE: printer << ' ' << "v2"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV2(); +// TYPE: printer.printStrippedAttrOrType(getV2()); // TYPE: printer << "v1"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV1(); +// TYPE: printer.printStrippedAttrOrType(getV1()); // TYPE: printer << ","; // TYPE: printer << ' ' << "v3"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV3(); +// TYPE: printer.printStrippedAttrOrType(getV3()); // TYPE: } def TypeC : TestType<"TestE"> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -256,16 +256,50 @@ // Format a custom attribute //===----------------------------------------------------------------------===// -// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]> -test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]> +// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]> +test.format_compound_attr <1, !test.smpla, [5, 6]> -// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { +//----- + + +// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { +module attributes {test.nested = #test.cmpnd_nested>} { +} + +//----- + +// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`. 
+// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { module attributes {test.nested = #test.cmpnd_nested>} { } -// CHECK: test.format_nested_attr #test.cmpnd_nested> +// CHECK: test.format_nested_attr > +test.format_nested_attr #test.cmpnd_nested> + +//----- + +// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`. +// CHECK: test.format_nested_attr > test.format_nested_attr #test.cmpnd_nested> +//----- + +// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>} +module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>} +{ +} + +//----- + +// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer>>} +module attributes {test.someAttr = #test.cmpnd_nested_outer>>} +{ +} + +//----- + +// CHECK: test.format_cpmd_nested_attr nested >> +test.format_cpmd_nested_attr nested >> //===----------------------------------------------------------------------===// // Format custom directives diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -13,6 +13,22 @@ return } +// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer>>) +func @compoundNested(%arg0: !test.cmpnd_nested_outer>>) -> () { + return +} + +// Same as above, but we're parsing the complete spec for the inner type +// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) +func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) -> () { +// Verify that the type prefix is elided and optional +// CHECK: format_cpmd_nested_type %arg0 nested >> +// CHECK: format_cpmd_nested_type %arg0 nested >> + test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + test.format_cpmd_nested_type %arg0 nested >> + return +} + // CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) func @testInt(%A 
: !test.int, %B : !test.int, %C : !test.int) { return diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -163,7 +163,8 @@ "::mlir::FieldParser<$0>::parse($_parser)"; /// Default printer for attribute or type parameters. -static const char *const defaultParameterPrinter = "$_printer << $_self"; +static const char *const defaultParameterPrinter = + "$_printer.printStrippedAttrOrType($_self)"; /// Print an error when failing to parse an element. /// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -496,13 +496,25 @@ /// {0}: The name of the attribute. /// {1}: The type for the attribute. const char *const attrParserCode = R"( - if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes)) + if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}", + result.attributes)) {{ + return ::mlir::failure(); + } +)"; + +/// The code snippet used to generate a parser call for an attribute whose storage type is the generic `::mlir::Attribute`; unlike `attrParserCode`, it calls `parseAttribute` instead of `parseCustomAttributeWithFallback`. +/// +/// {0}: The name of the attribute. +/// {1}: The type for the attribute.
+const char *const genericAttrParserCode = R"( + if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes)) return ::mlir::failure(); )"; + const char *const optionalAttrParserCode = R"( { ::mlir::OptionalParseResult parseResult = - parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes); + parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes); if (parseResult.hasValue() && failed(*parseResult)) return ::mlir::failure(); } @@ -635,8 +647,12 @@ } )"; const char *const typeParserCode = R"( - if (parser.parseType({0}RawTypes[0])) - return ::mlir::failure(); + { + {0} type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + {1}RawTypes[0] = type; + } )"; /// The code snippet used to generate a parser call for a functional type. @@ -1269,12 +1285,19 @@ std::string attrTypeStr; if (Optional typeBuilder = attr->getTypeBuilder()) { llvm::raw_string_ostream os(attrTypeStr); - os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + os << tgfmt(*typeBuilder, &attrTypeCtx); + } else { + attrTypeStr = "Type{}"; + } + if (var->attr.isOptional()) { + body << formatv(optionalAttrParserCode, var->name, attrTypeStr); + } else { + if (var->attr.getStorageType() == "::mlir::Attribute") + body << formatv(genericAttrParserCode, var->name, attrTypeStr); + else + body << formatv(attrParserCode, var->name, attrTypeStr); } - body << formatv(var->attr.isOptional() ? 
optionalAttrParserCode - : attrParserCode, - var->name, attrTypeStr); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; @@ -1334,14 +1357,23 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); - else if (lengthKind == ArgumentLengthKind::Variadic) + } else if (lengthKind == ArgumentLengthKind::Variadic) { body << llvm::formatv(variadicTypeParserCode, listName); - else if (lengthKind == ArgumentLengthKind::Optional) + } else if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(optionalTypeParserCode, listName); - else - body << formatv(typeParserCode, listName); + } else { + TypeSwitch(dir->getOperand()) + .Case([&](auto operand) { + body << formatv(typeParserCode, + operand->getVar()->constraint.getCPPClassName(), + listName); + }) + .Default([&](auto operand) { + body << formatv(typeParserCode, "Type", listName); + }); + } } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << formatv(functionalTypeParserCode, @@ -1761,7 +1793,8 @@ /// Generate the C++ for an operand to a (*-)type directive. static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, - MethodBody &body) { + MethodBody &body, + bool useArrayRef = true) { if (isa(arg)) return body << "getOperation()->getOperandTypes()"; if (isa(arg)) @@ -1778,8 +1811,10 @@ "({0}() ? 
::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " "::llvm::ArrayRef<::mlir::Type>())", op.getGetterName(var->name)); - return body << "::llvm::ArrayRef<::mlir::Type>(" - << op.getGetterName(var->name) << "().getType())"; + if (useArrayRef) + return body << "::llvm::ArrayRef<::mlir::Type>(" + << op.getGetterName(var->name) << "().getType())"; + return body << op.getGetterName(var->name) << "().getType()"; } /// Generate the printer for an enum attribute. @@ -1978,9 +2013,15 @@ if (attr->getTypeBuilder()) body << " _odsPrinter.printAttributeWithoutType(" << op.getGetterName(var->name) << "Attr());\n"; - else + else if (var->attr.isOptional()) + body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name) + << "Attr());\n"; + else if (var->attr.getStorageType() == "::mlir::Attribute") body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; + else + body << "_odsPrinter.printStrippedAttrOrType(" + << op.getGetterName(var->name) << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { if (operand->getVar()->isVariadicOfVariadic()) { body << " ::llvm::interleaveComma(" @@ -2033,8 +2074,29 @@ return; } } + const NamedTypeConstraint *var = nullptr; + { + if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + else if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + } + if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && + !var->isOptional()) { + std::string cppClass = var->constraint.getCPPClassName(); + body << " {\n" + << " auto type = " << op.getGetterName(var->name) + << "().getType();\n" + << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n" + << " _odsPrinter.printStrippedAttrOrType(validType);\n" + << " else\n" + << " _odsPrinter << type;\n" + << " }\n"; + return; + } body << " _odsPrinter << "; - genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n"; + genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false) + << ";\n"; } 
else if (auto *dir = dyn_cast(element)) { body << " _odsPrinter.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";