diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -31,6 +31,7 @@ vector operations, including a scalable vector type and intrinsics for some Arm SVE instructions. }]; + let useDefaultTypePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -66,20 +67,6 @@ "Type":$elementType ); - let printer = [{ - $_printer << "<"; - for (int64_t dim : getShape()) - $_printer << dim << 'x'; - $_printer << getElementType() << '>'; - }]; - - let parser = [{ - VectorType vector; - if ($_parser.parseType(vector)) - return Type(); - return get($_ctxt, vector.getShape(), vector.getElementType()); - }]; - let extraClassDeclaration = [{ bool hasStaticShape() const { return llvm::none_of(getShape(), ShapedType::isDynamic); diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -19,6 +19,8 @@ namespace mlir { class ShapedType; +class AsmParser; +class AsmPrinter; //===----------------------------------------------------------------------===// // ElementsAttr diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -417,6 +417,7 @@ return values->begin(); return llvm::None; } + }] # ElementsAttrInterfaceAccessors; } diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -64,7 +64,19 @@ AttributeT>> { static FailureOr parse(AsmParser &parser) { AttributeT value; - if (parser.parseAttribute(value)) + if (parser.parseCustomAttributeWithFallback(value)) + return failure(); + return value; + } +}; + +/// Parse an attribute. +template +struct FieldParser< + TypeT, std::enable_if_t::value, TypeT>> { + static FailureOr parse(AsmParser &parser) { + TypeT value; + if (parser.parseCustomTypeWithFallback(value)) return failure(); return value; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2971,6 +2971,9 @@ string baseCppClass = "::mlir::Type"> : DialectType, /*descr*/"", name # "Type">, AttrOrTypeDef<"Type", name, traits, baseCppClass> { + // Make it possible to use such type as parameters for other types. + string cppType = dialect.cppNamespace # "::" # cppClassName; + // A constant builder provided when the type has no parameters. let builderCall = !if(!empty(parameters), "$_builder.getType<" # dialect.cppNamespace # diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -22,6 +22,7 @@ namespace mlir { class Builder; +class AsmParser; //===----------------------------------------------------------------------===// // AsmPrinter @@ -50,6 +51,46 @@ virtual void printType(Type type); virtual void printAttribute(Attribute attr); + /// Trait to check if `AttrType` provides a `print` method. + template + using has_print_method = + decltype(std::declval().print(std::declval())); + template + using detect_has_print_method = + llvm::is_detected; + + /// Print the provided attribute in the context of an operation custom + /// printer/parser: this will invoke directly the print method on the + /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. + template + void printStrippedAttrOrType( + AttrOrType attrOrType, + std::enable_if_t::value> *sfinae = + nullptr) { + if (succeeded(printAlias(attrOrType))) + return; + attrOrType.print(*this); + } + + /// SFINAE for printing the provided attribute in the context of an operation + /// custom printer in the case where the attribute does not define a print + /// method. + template + void printStrippedAttrOrType( + AttrOrType attrOrType, + std::enable_if_t::value> *sfinae = + nullptr) { + *this << attrOrType; + } + + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + virtual LogicalResult printAlias(Attribute attr); + + /// Print the alias for the given type, return failure if no alias could + /// be printed. + virtual LogicalResult printAlias(Type type); + /// Print the given attribute without its type. The corresponding parser must /// provide a valid type for the attribute. virtual void printAttributeWithoutType(Attribute attr); @@ -608,6 +649,13 @@ /// Parse an arbitrary attribute of a given type and return it in result. virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; + /// Parse a custom attribute with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + virtual ParseResult parseCustomAttributeWithFallback( + Attribute &result, Type type, + llvm::function_ref + parseAttribute) = 0; + /// Parse an attribute of a specific kind and type. template ParseResult parseAttribute(AttrType &result, Type type = {}) { @@ -639,9 +687,9 @@ return parseAttribute(result, Type(), attrName, attrs); } - /// Parse an arbitrary attribute of a given type and return it in result. This - /// also adds the attribute to the specified attribute list with the specified - /// name. + /// Parse an arbitrary attribute of a given type and populate it in `result`. + /// This also adds the attribute to the specified attribute list with the + /// specified name. template ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, NamedAttrList &attrs) { @@ -661,6 +709,88 @@ return success(); } + /// Trait to check if `AttrType` provides a `parse` method. + template + using has_parse_method = decltype(AttrType::parse(std::declval(), + std::declval())); + template + using detect_has_parse_method = llvm::is_detected; + + /// Parse a custom attribute of a given type unless the next token is `#`, in + /// which case the generic parser is invoked. The parsed attribute is + /// populated in `result` and also added to the specified attribute list with + /// the specified name. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result, Type type, + StringRef attrName, NamedAttrList &attrs) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. + Attribute attr; + if (parseCustomAttributeWithFallback( + attr, type, + [&](::mlir::Attribute &result, + ::mlir::Type type) -> ::mlir::ParseResult { + result = AttrType::parse(*this, type); + if (!result) + return ::mlir::failure(); + return ::mlir::success(); + })) + return failure(); + + // Check for the right kind of attribute. + result = attr.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of attribute specified"); + + attrs.append(attrName, result); + return success(); + } + + /// SFINAE parsing method for Attribute that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result, Type type, + StringRef attrName, NamedAttrList &attrs) { + return parseAttribute(result, type, attrName, attrs); + } + + /// Parse a custom attribute of a given type unless the next token is `#`, in + /// which case the generic parser is invoked. The parsed attribute is + /// populated in `result`. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. + Attribute attr; + if (parseCustomAttributeWithFallback( + attr, {}, + [&](::mlir::Attribute &result, + ::mlir::Type type) -> ::mlir::ParseResult { + result = AttrType::parse(*this, type); + if (!result) + return ::mlir::failure(); + return ::mlir::success(); + })) + return failure(); + + // Check for the right kind of attribute. + result = attr.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of attribute specified"); + return success(); + } + + /// SFINAE parsing method for Attribute that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomAttributeWithFallback(AttrType &result) { + return parseAttribute(result); + } + /// Parse an arbitrary optional attribute of a given type and return it in /// result. virtual OptionalParseResult parseOptionalAttribute(Attribute &result, @@ -740,6 +870,12 @@ /// Parse a type. virtual ParseResult parseType(Type &result) = 0; + /// Parse a custom type with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + virtual ParseResult parseCustomTypeWithFallback( + Type &result, + llvm::function_ref parseType) = 0; + /// Parse an optional type. virtual OptionalParseResult parseOptionalType(Type &result) = 0; @@ -753,7 +889,7 @@ if (parseType(type)) return failure(); - // Check for the right kind of attribute. + // Check for the right kind of type. result = type.dyn_cast(); if (!result) return emitError(loc, "invalid kind of type specified"); @@ -761,6 +897,47 @@ return success(); } + /// Trait to check if `TypeT` provides a `parse` method. + template + using type_has_parse_method = + decltype(TypeT::parse(std::declval())); + template + using detect_type_has_parse_method = + llvm::is_detected; + + /// Parse a custom Type of a given type unless the next token is `#`, in + /// which case the generic parser is invoked. The parsed Type is + /// populated in `result`. + template + std::enable_if_t::value, ParseResult> + parseCustomTypeWithFallback(TypeT &result) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of Type. + Type type; + if (parseCustomTypeWithFallback( + type, [&](::mlir::Type &result) -> ::mlir::ParseResult { + result = TypeT::parse(*this); + if (!result) + return ::mlir::failure(); + return ::mlir::success(); + })) + return failure(); + + // Check for the right kind of Type. + result = type.dyn_cast(); + if (!result) + return emitError(loc, "invalid kind of Type specified"); + return success(); + } + + /// SFINAE parsing method for Type that don't implement a parse method. + template + std::enable_if_t::value, ParseResult> + parseCustomTypeWithFallback(TypeT &result) { + return parseType(result); + } + /// Parse a type list. ParseResult parseTypeList(SmallVectorImpl &result) { do { @@ -792,7 +969,7 @@ if (parseColonType(type)) return failure(); - // Check for the right kind of attribute. + // Check for the right kind of type. result = type.dyn_cast(); if (!result) return emitError(loc, "invalid kind of type specified"); diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -53,21 +53,21 @@ // ScalableVectorType //===----------------------------------------------------------------------===// -Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - { - Type genType; - auto parseResult = generatedTypeParser(parser, "vector", genType); - if (parseResult.hasValue()) - return genType; - } - parser.emitError(typeLoc, "unknown type in ArmSVE dialect"); - return Type(); +void ScalableVectorType::print(AsmPrinter &printer) const { + printer << "<"; + for (int64_t dim : getShape()) + printer << dim << 'x'; + printer << getElementType() << '>'; } -void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { - if (failed(generatedTypePrinter(type, os))) - llvm_unreachable("unexpected 'arm_sve' type kind"); +Type ScalableVectorType::parse(AsmParser &parser) { + SmallVector dims; + Type eltType; + if (parser.parseLess() || + parser.parseDimensionList(dims, /*allowDynamic=*/false) || + parser.parseType(eltType) || parser.parseGreater()) + return {}; + return ScalableVectorType::get(eltType.getContext(), dims, eltType); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -170,7 +170,7 @@ }; void CombiningKindAttr::print(AsmPrinter &printer) const { - printer << "kind<"; + printer << "<"; auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) { return bitEnumContains(this->getKind(), kind); }); @@ -215,10 +215,12 @@ void VectorDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { - if (auto ck = attr.dyn_cast()) + if (auto ck = attr.dyn_cast()) { + os << "kind"; ck.print(os); - else - llvm_unreachable("Unknown attribute type"); + return; + } + llvm_unreachable("Unknown attribute type"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1177,7 +1177,7 @@ /// Ex: /// ``` /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> -/// %1 = vector.multi_reduction #vector.kind, %0 [1] +/// %1 = vector.multi_reduction add, %0 [1] /// : vector<8x32x16xf32> to vector<8x16xf32> /// ``` /// Gets converted to: @@ -1187,7 +1187,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, %arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct MultiReduceToContract @@ -1236,7 +1236,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, %arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` /// Gets converted to: @@ -1246,7 +1246,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %arg0, %arg1, %cst_f0 +/// kind = add} %arg0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct CombineContractTranspose @@ -1293,7 +1293,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %0, %arg1, %cst_f0 +/// kind = add} %0, %arg1, %cst_f0 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` /// Gets converted to: @@ -1303,7 +1303,7 @@ /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, /// affine_map<(d0, d1, d2) -> (d0, d1)>], /// iterator_types = ["parallel", "parallel", "reduction"], -/// kind = #vector.kind} %arg0, %arg1, %cst_f0 +/// kind = add} %arg0, %arg1, %cst_f0 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> /// ``` struct CombineContractBroadcast diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -474,6 +474,14 @@ void printAttributeWithoutType(Attribute attr) override { printAttribute(attr); } + LogicalResult printAlias(Attribute attr) override { + initializer.visit(attr); + return success(); + } + LogicalResult printAlias(Type type) override { + initializer.visit(type); + return success(); + } /// Print the given set of attributes with names not included within /// 'elidedAttrs'. @@ -1252,8 +1260,16 @@ void printAttribute(Attribute attr, AttrTypeElision typeElision = AttrTypeElision::Never); + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + LogicalResult printAlias(Attribute attr); + void printType(Type type); + /// Print the alias for the given type, return failure if no alias could + /// be printed. + LogicalResult printAlias(Type attr); + /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -1594,6 +1610,14 @@ os << R"(opaque<"_", "0xDEADBEEF">)"; } +LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { + return success(state && succeeded(state->getAliasState().getAlias(attr, os))); +} + +LogicalResult AsmPrinter::Impl::printAlias(Type type) { + return success(state && succeeded(state->getAliasState().getAlias(type, os))); +} + void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -1602,7 +1626,7 @@ } // Try to print an alias for this attribute. - if (state && succeeded(state->getAliasState().getAlias(attr, os))) + if (succeeded(printAlias(attr))) return; if (!isa(attr.getDialect())) @@ -2106,6 +2130,16 @@ impl->printAttribute(attr); } +LogicalResult AsmPrinter::printAlias(Attribute attr) { + assert(impl && "expected AsmPrinter::printAlias to be overriden"); + return impl->printAlias(attr); +} + +LogicalResult AsmPrinter::printAlias(Type type) { + assert(impl && "expected AsmPrinter::printAlias to be overriden"); + return impl->printAlias(type); +} + void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/Sequence.h" using namespace mlir; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" @@ -62,6 +63,14 @@ return get(getContext(), vector); } +Attribute ArrayAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + +void ArrayAttr::print(AsmPrinter &printer) const { printer << *this; } + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// @@ -241,6 +250,12 @@ return getWithSorted(getContext(), vec); } +Attribute DictionaryAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // StringAttr //===----------------------------------------------------------------------===// @@ -270,6 +285,16 @@ return getImpl()->referencedDialect; } +void StringAttr::print(AsmPrinter &printer) const { + printer.printAttribute(*this); +} + +Attribute StringAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -300,6 +325,12 @@ return success(); } +Attribute FloatAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // SymbolRefAttr //===----------------------------------------------------------------------===// @@ -329,6 +360,12 @@ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); } +Attribute SymbolRefAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // IntegerAttr //===----------------------------------------------------------------------===// @@ -376,8 +413,19 @@ return attr.cast(); } +Attribute IntegerAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + +void IntegerAttr::print(AsmPrinter &printer) const { + printer.printAttribute(*this); +} + //===----------------------------------------------------------------------===// // BoolAttr +//===----------------------------------------------------------------------===// bool BoolAttr::getValue() const { auto *storage = reinterpret_cast(impl); @@ -414,6 +462,12 @@ return success(); } +Attribute OpaqueAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // DenseElementsAttr Utilities //===----------------------------------------------------------------------===// @@ -1152,6 +1206,16 @@ elementBitWidth, numElements); } +void DenseIntOrFPElementsAttr::print(AsmPrinter &printer) const { + printer.printAttribute(*this); +} + +Attribute DenseIntOrFPElementsAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} + //===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// @@ -1378,3 +1442,9 @@ function_ref walkTypesFn) const { walkTypesFn(getValue()); } + +Attribute TypeAttr::parse(AsmParser &parser, Type type) { + Attribute attr; + parser.parseAttribute(attr, type); + return attr; +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/BitVector.h" @@ -60,6 +61,15 @@ return success(); } +void ComplexType::print(mlir::AsmPrinter &printer) const { printer << *this; } + +Type ComplexType::parse(mlir::AsmParser &parser) { + ComplexType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// @@ -90,6 +100,15 @@ return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); } +void IntegerType::print(mlir::AsmPrinter &printer) const { printer << *this; } + +Type IntegerType::parse(mlir::AsmParser &parser) { + IntegerType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// // Float Type //===----------------------------------------------------------------------===// @@ -245,6 +264,15 @@ walkTypesFn(type); } +void FunctionType::print(mlir::AsmPrinter &printer) const { printer << *this; } + +Type FunctionType::parse(mlir::AsmParser &parser) { + FunctionType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -475,6 +503,15 @@ walkTypesFn(getElementType()); } +void VectorType::print(mlir::AsmPrinter &printer) const { printer << *this; } + +Type VectorType::parse(mlir::AsmParser &parser) { + VectorType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// @@ -786,6 +823,15 @@ walkAttrsFn(getMemorySpace()); } +void MemRefType::print(mlir::AsmPrinter &printer) const { printer << *this; } + +Type MemRefType::parse(mlir::AsmParser &parser) { + MemRefType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -952,6 +998,17 @@ walkAttrsFn(getMemorySpace()); } +void UnrankedMemRefType::print(mlir::AsmPrinter &printer) const { + printer << *this; +} + +Type UnrankedMemRefType::parse(mlir::AsmParser &parser) { + UnrankedMemRefType type; + if (failed(parser.parseType(type))) + return {}; + return type; +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -343,6 +343,29 @@ return success(static_cast(result)); } + /// Parse a custom attribute with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + ParseResult parseCustomAttributeWithFallback( + Attribute &result, Type type, + llvm::function_ref + parseAttribute) override { + if (parser.getToken().isNot(Token::hash_identifier)) + return parseAttribute(result, type); + result = parser.parseAttribute(type); + return success(static_cast(result)); + } + + /// Parse a custom attribute with the provided callback, unless the next + /// token is `#`, in which case the generic parser is invoked. + ParseResult parseCustomTypeWithFallback( + Type &result, + llvm::function_ref parseType) override { + if (parser.getToken().isNot(Token::exclamation_identifier)) + return parseType(result); + result = parser.parseType(); + return success(static_cast(result)); + } + OptionalParseResult parseOptionalAttribute(Attribute &result, Type type) override { return parser.parseOptionalAttribute(result, type); diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -3,7 +3,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 + // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32 %0 = arm_sve.sdot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -12,7 +12,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 + // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3 %0 = arm_sve.smmla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -21,7 +21,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32 + // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32 %0 = arm_sve.udot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -30,7 +30,7 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3 + // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3 %0 = arm_sve.ummla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir @@ -16,7 +16,7 @@ %0 = arith.addf %arg0, %arg0 : f32 // CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value %1 = async.runtime.create: !async.value -// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value +// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : async.runtime.store %0, %1: !async.value // CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value async.runtime.set_available %1: !async.value @@ -32,9 +32,9 @@ // CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: ^[[BRANCH_OK]]: -// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : // CHECK: %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32 -// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] @@ -84,8 +84,8 @@ // CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]] // CHECK: ^[[BRANCH_VALUE_OK]]: -// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value -// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : +// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] @@ -133,7 +133,7 @@ // CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]] // CHECK: ^[[BRANCH_VALUE_OK_1]]: -// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value +// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : // CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value) // CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]] // CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]] @@ -150,8 +150,8 @@ // CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]] // CHECK: ^[[BRANCH_VALUE_OK_2]]: -// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value -// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value +// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : +// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: br ^[[CLEANUP]] diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -245,7 +245,7 @@ } // CHECK: async.runtime.await %[[RET]]#1 : !async.value - // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value + // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : %0 = async.await %result : !async.value // CHECK: return %[[VALUE]] @@ -323,7 +323,7 @@ // // Load from the async.value argument after error checking. // CHECK: ^[[CONTINUATION:.*]]: -// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value) { - // CHECK: async.runtime.store %arg0, %arg1 : !async.value - async.runtime.store %arg0, %arg1 : !async.value + // CHECK: async.runtime.store %arg0, %arg1 : + async.runtime.store %arg0, %arg1 : return } // CHECK-LABEL: @load func @load(%arg0: !async.value) -> f32 { - // CHECK: %0 = async.runtime.load %arg0 : !async.value + // CHECK: %0 = async.runtime.load %arg0 : // CHECK: return %0 : f32 - %0 = async.runtime.load %arg0 : !async.value + %0 = async.runtime.load %arg0 : return %0 : f32 } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -6,7 +6,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [0] : vector<1584xf32> to f32 +// CHECK: vector.multi_reduction , %{{.*}} [0] : vector<1584xf32> to f32 // CHECK: arith.addf %{{.*}}, %{{.*}} : f32 linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>) outs(%C: memref) @@ -19,7 +19,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32> linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) outs(%C: memref<1584xf32>) @@ -31,7 +31,7 @@ // CHECK-LABEL: contraction_matmul func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32> linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) outs(%C: memref<1584x1584xf32>) @@ -43,7 +43,7 @@ // CHECK-LABEL: contraction_batch_matmul func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> -// CHECK: vector.multi_reduction #vector.kind, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> linalg.batch_matmul ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) @@ -71,7 +71,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> linalg.generic #matmul_trait @@ -105,7 +105,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32> linalg.generic #matmul_transpose_out_trait @@ -139,7 +139,7 @@ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32> // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> @@ -160,7 +160,7 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> + // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32> linalg.matmul ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) @@ -523,7 +523,7 @@ // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later // convert it to a 2D contract. // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32> // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32> // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) @@ -744,7 +744,7 @@ // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32> // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> // CHECK: math.exp {{.*}} : vector<4x16x8xf32> - // CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> + // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> // CHECK: addf {{.*}} : vector<4x16xf32> // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32> // CHECK: return {{.*}} : tensor<4x16xf32> @@ -779,7 +779,7 @@ // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: addf {{.*}} : vector<2x3x4x5xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> + // CHECK: vector.multi_reduction , {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> // CHECK: addf {{.*}} : vector<2x5xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32> // CHECK: return {{.*}} : tensor<5x2xf32> @@ -808,7 +808,7 @@ // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 @@ -833,7 +833,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: minf %[[R]], %[[CMAXF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 @@ -857,7 +857,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: mulf {{.*}} : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 @@ -881,7 +881,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -904,7 +904,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant true %init = linalg.init_tensor [4] : tensor<4xi1> @@ -927,7 +927,7 @@ // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -979,7 +979,7 @@ // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: math.exp {{.*}} : vector<4x4xf32> - // CHECK: vector.multi_reduction #vector.kind, {{.*}} : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}} : vector<4x4xf32> to vector<4xf32> // CHECK: addf {{.*}} : vector<4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> %c0 = arith.constant 0.0 : f32 @@ -1018,7 +1018,7 @@ %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> - // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[r]] [0] + // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]] [0] // CHECK-SAME: : vector<32xf32> to f32 // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[F0]] : f32 // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1054,7 +1054,7 @@ // CHECK-LABEL: func @vector_multi_reduction_single_parallel( // CHECK-SAME: %[[v:.*]]: vector<2xf32> func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [] : vector<2xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [] : vector<2xf32> to vector<2xf32> // CHECK: return %[[v]] : vector<2xf32> return %0 : vector<2xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -3,7 +3,7 @@ // ----- func @broadcast_to_scalar(%arg0: f32) -> f32 { - // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}} + // expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}} %0 = vector.broadcast %arg0 : f32 to f32 } @@ -969,7 +969,7 @@ // ----- func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) { - // expected-error@+1 {{must be vector of any type values}} + // expected-error@+1 {{'vector.bitcast' invalid kind of type specified}} %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32 } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -638,9 +638,9 @@ // CHECK-LABEL: @multi_reduction func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 { - %1 = vector.multi_reduction #vector.kind, %0 [1, 3] : + %1 = vector.multi_reduction , %0 [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32> - %2 = vector.multi_reduction #vector.kind, %1 [0, 1] : + %2 = vector.multi_reduction , %1 [0, 1] : vector<4x16xf32> to f32 return %2 : f32 } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 { - %0 = vector.multi_reduction #vector.kind, %arg0 [0, 1] : vector<2x4xf32> to f32 + %0 = vector.multi_reduction , %arg0 [0, 1] : vector<2x4xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @vector_multi_reduction_to_scalar @@ -30,7 +30,7 @@ // CHECK: return %[[RES]] func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> + %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } // CHECK-LABEL: func @vector_reduction_inner @@ -66,7 +66,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> + %0 = vector.multi_reduction , %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> return %0 : vector<2x5xf32> } @@ -78,7 +78,7 @@ // CHECK: return %[[RESULT]] func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> + %0 = vector.multi_reduction , %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> return %0 : vector<2x4xf32> } // CHECK-LABEL: func @vector_multi_reduction_ordering diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -35,7 +35,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -52,7 +52,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -69,7 +69,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xi32> func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -86,7 +86,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xi32> func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } @@ -104,7 +104,7 @@ func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> + %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -12,7 +12,7 @@ func @multidimreduction_contract( %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> { %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> - %1 = vector.multi_reduction #vector.kind, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32> + %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xf32> to vector<8x16xf32> return %1 : vector<8x16xf32> } @@ -30,7 +30,7 @@ func @multidimreduction_contract_int( %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> { %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32> - %1 = vector.multi_reduction #vector.kind, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32> + %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xi32> to vector<8x16xi32> return %1 : vector<8x16xi32> } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -14,7 +14,7 @@ #define TEST_ATTRDEFS // To get the test dialect definition. -include "TestOps.td" +include "TestDialect.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/SubElementInterfaces.td" @@ -121,6 +121,29 @@ ); } +// A more complex parameterized attribute with multiple level of nesting. +def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> { + let mnemonic = "cmpnd_nested_inner"; + // List of type parameters. + let parameters = ( + ins + "int":$some_int, + CompoundAttrA:$cmpdA + ); + let assemblyFormat = "`<` $some_int $cmpdA `>`"; +} + +def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> { + let mnemonic = "cmpnd_nested_outer"; + + // List of type parameters. + let parameters = ( + ins + CompoundNestedInner:$inner + ); + let assemblyFormat = "`<` `i` $inner `>`"; +} + def TestParamOne : AttrParameter<"int64_t", ""> {} def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -0,0 +1,46 @@ +//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_DIALECT +#define TEST_DIALECT + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "::test"; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let hasCanonicalizer = 1; + let hasConstantMaterializer = 1; + let hasOperationAttrVerify = 1; + let hasRegionArgAttrVerify = 1; + let hasRegionResultAttrVerify = 1; + let hasOperationInterfaceFallback = 1; + let hasNonDefaultDestructor = 1; + let useDefaultAttributePrinterParser = 1; + let dependentDialects = ["::mlir::DLTIDialect"]; + + let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + + // Provides a custom printing/parsing for some operations. + ::llvm::Optional + getParseOperationHook(::llvm::StringRef opName) const override; + ::llvm::unique_function + getOperationPrinter(::mlir::Operation *op) const override; + + private: + // Storage for a custom fallback interface. + void *fallbackEffectOpInterfaces; + + }]; +} + +#endif // TEST_DIALECT diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -9,6 +9,7 @@ #ifndef TEST_OPS #define TEST_OPS +include "TestDialect.td" include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" @@ -23,40 +24,11 @@ include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "TestInterfaces.td" -def Test_Dialect : Dialect { - let name = "test"; - let cppNamespace = "::test"; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; - let hasCanonicalizer = 1; - let hasConstantMaterializer = 1; - let hasOperationAttrVerify = 1; - let hasRegionArgAttrVerify = 1; - let hasRegionResultAttrVerify = 1; - let hasOperationInterfaceFallback = 1; - let hasNonDefaultDestructor = 1; - let useDefaultAttributePrinterParser = 1; - let dependentDialects = ["::mlir::DLTIDialect"]; - - let extraClassDeclaration = [{ - void registerAttributes(); - void registerTypes(); - - // Provides a custom printing/parsing for some operations. - ::llvm::Optional - getParseOperationHook(::llvm::StringRef opName) const override; - ::llvm::unique_function - getOperationPrinter(::mlir::Operation *op) const override; - - private: - // Storage for a custom fallback interface. - void *fallbackEffectOpInterfaces; - - }]; -} // Include the attribute definitions. include "TestAttrDefs.td" +// Include the type definitions. +include "TestTypeDefs.td" class TEST_Op traits = []> : @@ -1907,6 +1879,16 @@ let assemblyFormat = "$nested attr-dict-with-keyword"; } +def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; +} + +def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { + let arguments = (ins CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; +} + //===----------------------------------------------------------------------===// // Custom Directives diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -14,8 +14,9 @@ #define TEST_TYPEDEFS // To get the test dialect def. -include "TestOps.td" +include "TestDialect.td" include "TestAttrDefs.td" +include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" @@ -49,6 +50,29 @@ }]; } +// A more complex and nested parameterized type. +def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> { + let mnemonic = "cmpnd_inner"; + // List of type parameters. + let parameters = ( + ins + "int":$some_int, + CompoundTypeA:$cmpdA + ); + let assemblyFormat = "`<` $some_int $cmpdA `>`"; +} + +def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> { + let mnemonic = "cmpnd_nested_outer"; + + // List of type parameters. + let parameters = ( + ins + CompoundNestedInnerType:$inner + ); + let assemblyFormat = "`<` `i` $inner `>`"; +} + // An example of how one could implement a standard integer. def IntegerType : Test_Type<"TestInteger"> { let mnemonic = "int"; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -25,6 +25,7 @@ #include "mlir/Interfaces/DataLayoutInterfaces.h" namespace test { +class TestAttrWithFormatAttr; /// FieldInfo represents a field in the StructType data type. It is used as a /// parameter in TestTypeDefs.td. @@ -63,13 +64,13 @@ return test::CustomParam{value.getValue()}; } }; -} // end namespace mlir - inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer, - const test::CustomParam ¶m) { + test::CustomParam param) { return printer << param.value; } +} // end namespace mlir + #include "TestTypeInterfaces.h.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -301,3 +301,6 @@ SetVector stack; printTestType(type, printer, stack); } + +static_assert( + AsmPrinter::detect_has_print_method::value, ""); diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -61,7 +61,7 @@ // ATTR: printer << ' ' << "hello"; // ATTR: printer << ' ' << "="; // ATTR: printer << ' '; -// ATTR: printer << getValue(); +// ATTR: printer.printStrippedAttrOrType(getValue()); // ATTR: printer << ","; // ATTR: printer << ' '; // ATTR: ::printAttrParamA(printer, getComplex()); @@ -154,10 +154,10 @@ // ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const { // ATTR: printer << ' '; -// ATTR: printer << getV0(); +// ATTR: printer.printStrippedAttrOrType(getV0()); // ATTR: printer << ","; // ATTR: printer << ' '; -// ATTR: printer << getV1(); +// ATTR: printer.printStrippedAttrOrType(getV1()); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -213,7 +213,7 @@ // TYPE: printer << ' ' << "bob"; // TYPE: printer << ' ' << "bar"; // TYPE: printer << ' '; -// TYPE: printer << getValue(); +// TYPE: printer.printStrippedAttrOrType(getValue()); // TYPE: printer << ' ' << "complex"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; @@ -361,21 +361,21 @@ // TYPE: printer << "v0"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV0(); +// TYPE: printer.printStrippedAttrOrType(getV0()); // TYPE: printer << ","; // TYPE: printer << ' ' << "v2"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV2(); +// TYPE: printer.printStrippedAttrOrType(getV2()); // TYPE: printer << "v1"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV1(); +// TYPE: printer.printStrippedAttrOrType(getV1()); // TYPE: printer << ","; // TYPE: printer << ' ' << "v3"; // TYPE: printer << ' ' << "="; // TYPE: printer << ' '; -// TYPE: printer << getV3(); +// TYPE: printer.printStrippedAttrOrType(getV3()); // TYPE: } def TypeC : TestType<"TestE"> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -256,16 +256,50 @@ // Format a custom attribute //===----------------------------------------------------------------------===// -// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]> -test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]> +// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]> +test.format_compound_attr <1, !test.smpla, [5, 6]> -// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { +//----- + + +// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { +module attributes {test.nested = #test.cmpnd_nested>} { +} + +//----- + +// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`. +// CHECK: module attributes {test.nested = #test.cmpnd_nested>} { module attributes {test.nested = #test.cmpnd_nested>} { } -// CHECK: test.format_nested_attr #test.cmpnd_nested> +// CHECK: test.format_nested_attr > +test.format_nested_attr #test.cmpnd_nested> + +//----- + +// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`. +// CHECK: test.format_nested_attr > test.format_nested_attr #test.cmpnd_nested> +//----- + +// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>} +module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>} +{ +} + +//----- + +// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer>>} +module attributes {test.someAttr = #test.cmpnd_nested_outer>>} +{ +} + +//----- + +// CHECK: test.format_cpmd_nested_attr nested >> +test.format_cpmd_nested_attr nested >> //===----------------------------------------------------------------------===// // Format custom directives diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -13,6 +13,22 @@ return } +// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer>>) +func @compoundNested(%arg0: !test.cmpnd_nested_outer>>) -> () { + return +} + +// Same as above, but we're parsing the complete spec for the inner type +// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) +func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) -> () { +// Verify that the type prefix is elided and optional +// CHECK: format_cpmd_nested_type %arg0 nested >> +// CHECK: format_cpmd_nested_type %arg0 nested >> + test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + test.format_cpmd_nested_type %arg0 nested >> + return +} + // CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) { return diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -259,6 +259,7 @@ /// {1}: Extra parser parameters. static const char *const defDeclParsePrintStr = R"( static ::mlir::{0} parse(::mlir::AsmParser &parser{1}); + using Base::print; void print(::mlir::AsmPrinter &printer) const; )"; @@ -411,13 +412,12 @@ os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n" << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n" << " }\n"; - - // If mnemonic specified, emit print/parse declarations. - if (def.getParserCode() || def.getPrinterCode() || - def.getAssemblyFormat() || !params.empty()) { - os << llvm::formatv(defDeclParsePrintStr, valueType, - isAttrGenerator ? ", ::mlir::Type type" : ""); - } + } + // Emit print/parse declarations. + if (def.getParserCode() || def.getPrinterCode() || def.getAssemblyFormat() || + !params.empty()) { + os << llvm::formatv(defDeclParsePrintStr, valueType, + isAttrGenerator ? ", ::mlir::Type type" : ""); } if (def.genAccessors()) { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -178,7 +178,8 @@ "::mlir::FieldParser<$0>::parse($_parser)"; /// Default printer for attribute or type parameters. -static const char *const defaultParameterPrinter = "$_printer << $_self"; +static const char *const defaultParameterPrinter = + "$_printer.printStrippedAttrOrType($_self)"; /// Print an error when failing to parse an element. /// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -495,13 +495,25 @@ /// {0}: The name of the attribute. /// {1}: The type for the attribute. const char *const attrParserCode = R"( - if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes)) + if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}", + result.attributes)) {{ + return ::mlir::failure(); + } +)"; + +/// The code snippet used to generate a parser call for an attribute. +/// +/// {0}: The name of the attribute. +/// {1}: The type for the attribute. +const char *const genericAttrParserCode = R"( + if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes)) return ::mlir::failure(); )"; + const char *const optionalAttrParserCode = R"( { ::mlir::OptionalParseResult parseResult = - parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes); + parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes); if (parseResult.hasValue() && failed(*parseResult)) return ::mlir::failure(); } @@ -634,8 +646,12 @@ } )"; const char *const typeParserCode = R"( - if (parser.parseType({0}RawTypes[0])) - return ::mlir::failure(); + { + {0} type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + {1}RawTypes[0] = type; + } )"; /// The code snippet used to generate a parser call for a functional type. @@ -1268,12 +1284,19 @@ std::string attrTypeStr; if (Optional typeBuilder = attr->getTypeBuilder()) { llvm::raw_string_ostream os(attrTypeStr); - os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + os << tgfmt(*typeBuilder, &attrTypeCtx); + } else { + attrTypeStr = "Type{}"; + } + if (var->attr.isOptional()) { + body << formatv(optionalAttrParserCode, var->name, attrTypeStr); + } else { + if (var->attr.getStorageType() == "::mlir::Attribute") + body << formatv(genericAttrParserCode, var->name, attrTypeStr); + else + body << formatv(attrParserCode, var->name, attrTypeStr); } - body << formatv(var->attr.isOptional() ? optionalAttrParserCode - : attrParserCode, - var->name, attrTypeStr); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; @@ -1333,14 +1356,26 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); - else if (lengthKind == ArgumentLengthKind::Variadic) + } else if (lengthKind == ArgumentLengthKind::Variadic) { body << llvm::formatv(variadicTypeParserCode, listName); - else if (lengthKind == ArgumentLengthKind::Optional) + } else if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(optionalTypeParserCode, listName); - else - body << formatv(typeParserCode, listName); + } else { + const NamedTypeConstraint *var = nullptr; + { + if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + } + if (var) + body << formatv(typeParserCode, var->constraint.getCPPClassName(), + listName); + else + body << formatv(typeParserCode, "Type", listName); + } } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << formatv(functionalTypeParserCode, @@ -1760,7 +1795,8 @@ /// Generate the C++ for an operand to a (*-)type directive. static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, - MethodBody &body) { + MethodBody &body, + bool useArrayRef = true) { if (isa(arg)) return body << "getOperation()->getOperandTypes()"; if (isa(arg)) @@ -1777,8 +1813,10 @@ "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " "::llvm::ArrayRef<::mlir::Type>())", op.getGetterName(var->name)); - return body << "::llvm::ArrayRef<::mlir::Type>(" - << op.getGetterName(var->name) << "().getType())"; + if (useArrayRef) + return body << "::llvm::ArrayRef<::mlir::Type>(" + << op.getGetterName(var->name) << "().getType())"; + return body << op.getGetterName(var->name) << "().getType()"; } /// Generate the printer for an enum attribute. @@ -1977,9 +2015,15 @@ if (attr->getTypeBuilder()) body << " _odsPrinter.printAttributeWithoutType(" << op.getGetterName(var->name) << "Attr());\n"; - else + else if (var->attr.isOptional()) + body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name) + << "Attr());\n"; + else if (var->attr.getStorageType() == "::mlir::Attribute") body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; + else + body << "_odsPrinter.printStrippedAttrOrType(" + << op.getGetterName(var->name) << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { if (operand->getVar()->isVariadicOfVariadic()) { body << " ::llvm::interleaveComma(" @@ -2032,8 +2076,29 @@ return; } } + const NamedTypeConstraint *var = nullptr; + { + if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + if (auto *operand = dyn_cast(dir->getOperand())) + var = operand->getVar(); + } + if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && + !var->isOptional()) { + std::string cppClass = var->constraint.getCPPClassName(); + body << " {\n" + << " auto type = " << op.getGetterName(var->name) + << "().getType();\n" + << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n" + << " _odsPrinter.printStrippedAttrOrType(validType);\n" + << " else\n" + << " _odsPrinter << type;\n" + << " }\n"; + return; + } body << " _odsPrinter << "; - genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n"; + genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false) + << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " _odsPrinter.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";