diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md --- a/mlir/docs/Tutorials/Toy/Ch-2.md +++ b/mlir/docs/Tutorials/Toy/Ch-2.md @@ -210,7 +210,9 @@ /// The ConstantOp takes no inputs. mlir::OpTrait::ZeroOperands, /// The ConstantOp returns a single result. - mlir::OpTrait::OneResult> { + mlir::OpTrait::OneResult, + /// The result of getType is `Type`. + mlir::OpTrait::OneTypedResult<Type>::Impl> { public: /// Inherit the constructors from the base Op class. diff --git a/mlir/examples/standalone/include/Standalone/StandaloneOps.h b/mlir/examples/standalone/include/Standalone/StandaloneOps.h --- a/mlir/examples/standalone/include/Standalone/StandaloneOps.h +++ b/mlir/examples/standalone/include/Standalone/StandaloneOps.h @@ -9,6 +9,7 @@ #ifndef STANDALONE_STANDALONEOPS_H #define STANDALONE_STANDALONEOPS_H +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h --- a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h +++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_ #define MLIR_DIALECT_AVX512_AVX512DIALECT_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ #define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git 
a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -175,8 +175,12 @@ // are considered as uncategorized constraints. // Subclass for constraints on a type. -class TypeConstraint<Pred predicate, string description = ""> : - Constraint<predicate, description>; +class TypeConstraint<Pred predicate, string description = "", + string cppClassNameParam = "::mlir::Type"> : + Constraint<predicate, description> { + // The name of the C++ Type class if known, or Type if not. + string cppClassName = cppClassNameParam; +} // Subclass for constraints on an attribute. class AttrConstraint<Pred predicate, string description = ""> : @@ -285,8 +289,9 @@ //===----------------------------------------------------------------------===// // A type, carries type constraints. -class Type<Pred condition, string descr = ""> : - TypeConstraint<condition, descr> { +class Type<Pred condition, string descr = "", + string cppClassName = "::mlir::Type"> : + TypeConstraint<condition, descr, cppClassName> { string typeDescription = ""; string builderCall = ""; } @@ -299,8 +304,9 @@ } // A type of a specific dialect. 
-class DialectType<Dialect d, Pred condition, string descr = ""> : - Type<condition, descr> { +class DialectType<Dialect d, Pred condition, string descr = "", + string cppClassName = "::mlir::Type"> : + Type<condition, descr, cppClassName> { Dialect dialect = d; } @@ -331,11 +337,13 @@ def AnyType : Type<CPred<"true">, "any type">; // None type -def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type">, +def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type", + "::mlir::NoneType">, BuildableType<"$_builder.getType<::mlir::NoneType>()">; // Any type from the given list -class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type< +class AnyTypeOf<list<Type> allowedTypes, string description = "", + string cppClassName = "::mlir::Type"> : Type< // Satisfy any of the allowed type's condition Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>, !if(!eq(description, ""), @@ -345,7 +353,8 @@ // Integer types. // Any integer type irrespective of its width and signedness semantics. -def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer">; +def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer", + "::mlir::IntegerType">; // Any integer type (regardless of signedness semantics) of a specific width. class AnyI<int width> @@ -355,7 +364,8 @@ class AnyIntOfWidths<list<int> widths> : AnyTypeOf<!foreach(w, widths, AnyI<w>), - StrJoinInt<widths, "/">.result # "-bit integer">; + StrJoinInt<widths, "/">.result # "-bit integer", + "::mlir::IntegerType">; def AnyI1 : AnyI<1>; def AnyI8 : AnyI<8>; @@ -365,12 +375,13 @@ // Any signless integer type irrespective of its width. def AnySignlessInteger : Type< - CPred<"$_self.isSignlessInteger()">, "signless integer">; + CPred<"$_self.isSignlessInteger()">, "signless integer", + "::mlir::IntegerType">; // Signless integer type of a specific width. 
class I<int width> : Type<CPred<"$_self.isSignlessInteger(" # width # ")">, - width # "-bit signless integer">, + width # "-bit signless integer", "::mlir::IntegerType">, BuildableType<"$_builder.getIntegerType(" # width # ")"> { int bitwidth = width; } @@ -392,7 +403,7 @@ // Signed integer type of a specific width. class SI<int width> : Type<CPred<"$_self.isSignedInteger(" # width # ")">, - width # "-bit signed integer">, + width # "-bit signed integer", "::mlir::IntegerType">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> { int bitwidth = width; @@ -415,7 +426,7 @@ // Unsigned integer type of a specific width. class UI<int width> : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">, - width # "-bit unsigned integer">, + width # "-bit unsigned integer", "::mlir::IntegerType">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> { int bitwidth = width; @@ -432,18 +443,20 @@ def UI64 : UI<64>; // Index type. -def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index">, +def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index", + "::mlir::IndexType">, BuildableType<"$_builder.getIndexType()">; // Floating point types. // Any float type irrespective of its width. -def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point">; +def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point", + "::mlir::FloatType">; // Float type of a specific width. 
class F<int width> : Type<CPred<"$_self.isF" # width # "()">, - width # "-bit float">, + width # "-bit float", "::mlir::FloatType">, BuildableType<"$_builder.getF" # width # "Type()"> { int bitwidth = width; } @@ -465,16 +478,17 @@ SubstLeaves<"$_self", "$_self.cast<::mlir::ComplexType>().getElementType()", type.predicate>]>, - "complex type with " # type.description # " elements"> { + "complex type with " # type.description # " elements", + "::mlir::ComplexType"> { Type elementType = type; } def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">, - "complex-type">; + "complex-type", "::mlir::ComplexType">; class OpaqueType<string dialect, string name, string description> : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">, - description>, + description, "::mlir::OpaqueType">, BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), " "$_builder.getIdentifier(\"" # dialect # "\"), \"" # name # "\")">; @@ -483,17 +497,17 @@ // Any function type. def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">, - "function type">; + "function type", "::mlir::FunctionType">; // A container type is a type that has another type embedded within it. class ContainerType<Type etype, Pred containerPred, code elementTypeCall, - string descr> : + string descr, string cppClassName = "::mlir::Type"> : // First, check the container predicate. Then, substitute the extracted // element into the element type checker. Type<And<[containerPred, SubstLeaves<"$_self", !cast<string>(elementTypeCall), etype.predicate>]>, - descr # " of " # etype.description # " values"> { + descr # " of " # etype.description # " values", cppClassName> { // The type of elements in the container. 
Type elementType = etype; @@ -502,9 +516,11 @@ } class ShapedContainerType<list<Type> allowedTypes, - Pred containerPred, string descr> : + Pred containerPred, string descr, + string cppClassName = "::mlir::Type"> : ContainerType<AnyTypeOf<allowedTypes>, containerPred, - "$_self.cast<::mlir::ShapedType>().getElementType()", descr>; + "$_self.cast<::mlir::ShapedType>().getElementType()", descr, + cppClassName>; // Whether a shaped type is ranked. def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; @@ -520,7 +536,8 @@ // Vector types. class VectorOf<list<Type> allowedTypes> : - ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">; + ShapedContainerType<allowedTypes, IsVectorTypePred, "vector", + "::mlir::VectorType">; // Whether the number of elements of a vector is from the given // `allowedRanks` list @@ -534,7 +551,7 @@ // Any vector where the rank is from the given `allowedRanks` list class VectorOfRank<list<int> allowedRanks> : Type< IsVectorOfRankPred<allowedRanks>, - " of ranks " # StrJoinInt<allowedRanks, "/">.result>; + " of ranks " # StrJoinInt<allowedRanks, "/">.result, "::mlir::VectorType">; // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list @@ -543,7 +560,8 @@ And<[VectorOf<allowedTypes>.predicate, VectorOfRank<allowedRanks>.predicate]>, VectorOf<allowedTypes>.description # - VectorOfRank<allowedRanks>.description>; + VectorOfRank<allowedRanks>.description, + "::mlir::VectorType">; // Whether the number of elements of a vector is from the given // `allowedLengths` list @@ -558,7 +576,8 @@ // `allowedLengths` list class VectorOfLength<list<int> allowedLengths> : Type< IsVectorOfLengthPred<allowedLengths>, - " of length " # StrJoinInt<allowedLengths, "/">.result>; + " of length " # StrJoinInt<allowedLengths, "/">.result, + "::mlir::VectorType">; // Any vector where the number of elements is from the given @@ -569,30 +588,34 @@ 
And<[VectorOf<allowedTypes>.predicate, VectorOfLength<allowedLengths>.predicate]>, VectorOf<allowedTypes>.description # - VectorOfLength<allowedLengths>.description>; + VectorOfLength<allowedLengths>.description, + "::mlir::VectorType">; def AnyVector : VectorOf<[AnyType]>; // Shaped types. -def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; +def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", + "::mlir::ShapedType">; // Tensor types. // Any tensor type whose element type is from the given `allowedTypes` list class TensorOf<list<Type> allowedTypes> : - ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">; + ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor", + "::mlir::TensorType">; def AnyTensor : TensorOf<[AnyType]>; def AnyRankedTensor : ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>, - "ranked tensor">; + "ranked tensor", "::mlir::TensorType">; // TODO: Have an easy way to add another constraint to a type. 
class StaticShapeTensorOf<list<Type> allowedTypes> : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>, - "statically shaped " # TensorOf<allowedTypes>.description>; + "statically shaped " # TensorOf<allowedTypes>.description, + "::mlir::TensorType">; def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; @@ -612,7 +635,7 @@ class TensorRankOf<list<Type> allowedTypes, list<int> ranks> : Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>, StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " # - TensorOf<allowedTypes>.description>; + TensorOf<allowedTypes>.description, "::mlir::TensorType">; class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>; class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>; @@ -623,12 +646,14 @@ // Unranked Memref type def AnyUnrankedMemRef : ShapedContainerType<[AnyType], - IsUnrankedMemRefTypePred, "unranked.memref">; + IsUnrankedMemRefTypePred, "unranked.memref", + "::mlir::UnrankedMemRefType">; // Memref type. // Memrefs are blocks of data with fixed type and rank. class MemRefOf<list<Type> allowedTypes> : - ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">; + ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref", + "::mlir::MemRefType">; def AnyMemRef : MemRefOf<[AnyType]>; @@ -679,7 +704,7 @@ MemRefOf<allowedTypes>.description>; // This represents a generic tuple without any constraints on element type. -def AnyTuple : Type<IsTupleTypePred, "tuple">; +def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">; // A container type that has other types embedded in it, but (unlike // ContainerType) can hold elements with a mix of types. Requires a call that @@ -2414,9 +2439,7 @@ // the given C++ base class. class TypeDef<Dialect dialect, string name, string baseCppClass = "::mlir::Type"> - : DialectType<dialect, CPred<"">> { - // The name of the C++ Type class. 
- string cppClassName = name # "Type"; + : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> { // The name of the C++ base class to use for this Type. string cppBaseClassName = baseCppClass; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -28,10 +28,6 @@ class Builder; class OpBuilder; -namespace OpTrait { -template <typename ConcreteType> class OneResult; -} - /// This class represents success/failure for operation parsing. It is /// essentially a simple wrapper class around LogicalResult that allows for /// explicit conversion to bool. This allows for the parser to chain together @@ -188,7 +184,8 @@ void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); } /// Set the dialect attributes for this operation, and preserve all dependent. - template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) { + template <typename DialectAttrs> + void setDialectAttrs(DialectAttrs &&attrs) { state->setDialectAttrs(std::forward<DialectAttrs>(attrs)); } @@ -424,7 +421,8 @@ /// /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> { /// -template <unsigned N> class NOperands { +template <unsigned N> +class NOperands { public: static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); @@ -443,7 +441,8 @@ /// /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> { /// -template <unsigned N> class AtLeastNOperands { +template <unsigned N> +class AtLeastNOperands { public: template <typename ConcreteType> class Impl : public detail::MultiOperandTraitBase<ConcreteType, @@ -517,7 +516,8 @@ /// This class provides the API for ops that are known to have a specified /// number of regions. 
-template <unsigned N> class NRegions { +template <unsigned N> +class NRegions { public: static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2"); @@ -533,7 +533,8 @@ /// This class provides APIs for ops that are known to have at least a specified /// number of regions. -template <unsigned N> class AtLeastNRegions { +template <unsigned N> +class AtLeastNRegions { public: template <typename ConcreteType> class Impl : public detail::MultiRegionTraitBase<ConcreteType, @@ -582,7 +583,8 @@ /// Replace all uses of results of this operation with the provided 'values'. /// 'values' may correspond to an existing operation, or a range of 'Value'. - template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) { + template <typename ValuesT> + void replaceAllUsesWith(ValuesT &&values) { this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); } @@ -610,20 +612,19 @@ } // end namespace detail /// This class provides return value APIs for ops that are known to have a -/// single result. +/// single result. getType() is provided separately by OneTypedResult. template <typename ConcreteType> class OneResult : public TraitBase<ConcreteType, OneResult> { public: Value getResult() { return this->getOperation()->getResult(0); } - Type getType() { return getResult().getType(); } /// If the operation returns a single value, then the Op can be implicitly /// converted to an Value. This yields the value of the only result. operator Value() { return getResult(); } - /// Replace all uses of 'this' value with the new value, updating anything in - /// the IR that uses 'this' to use the other value instead. When this returns - /// there are zero uses of 'this'. + /// Replace all uses of 'this' value with the new value, updating anything + /// in the IR that uses 'this' to use the other value instead. When this + /// returns there are zero uses of 'this'. 
void replaceAllUsesWith(Value newValue) { getResult().replaceAllUsesWith(newValue); } @@ -638,12 +639,33 @@ } }; +/// This trait is used for return value APIs for ops that are known to have a +/// specific type other than `Type`. This allows the "getType()" member to be +/// more specific for an op. This should be used in conjunction with OneResult, +/// and occur in the trait list before OneResult. +template <typename ResultType> +class OneTypedResult { +public: + /// This class provides return value APIs for ops that are known to have a + /// single result. ResultType is the concrete type returned by getType(). + template <typename ConcreteType> + class Impl + : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { + public: + ResultType getType() { + auto resultTy = this->getOperation()->getResult(0).getType(); + return resultTy.template cast<ResultType>(); + } + }; +}; + /// This class provides the API for ops that are known to have a specified /// number of results. This is used as a trait like this: /// /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> { /// -template <unsigned N> class NResults { +template <unsigned N> +class NResults { public: static_assert(N > 1, "use ZeroResult/OneResult for N < 2"); @@ -662,7 +684,8 @@ /// /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> { /// -template <unsigned N> class AtLeastNResults { +template <unsigned N> +class AtLeastNResults { public: template <typename ConcreteType> class Impl : public detail::MultiResultTraitBase<ConcreteType, @@ -1573,7 +1596,8 @@ using has_fold = decltype( std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(), std::declval<SmallVectorImpl<OpFoldResult> &>())); - template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>; + template <typename T> + using detect_has_fold = llvm::is_detected<has_fold, T>; /// Trait to check if T provides a 'print' method. template <typename T, typename... 
Args> using has_print = diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -47,6 +47,9 @@ // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise. Optional<StringRef> getBuilderCall() const; + + // Return the C++ class name for this type (which may just be ::mlir::Type). + StringRef getCPPClassName() const; }; // Wrapper class with helper methods for accessing Types defined in TableGen. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -612,7 +612,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - auto dstType = op->getResult(0).getType().cast<VectorType>(); + auto dstType = op.getType(); int64_t rank = dstType.getRank(); if (rank == 1) { rewriter.replaceOp( @@ -1091,8 +1091,7 @@ auto loc = castOp->getLoc(); MemRefType sourceMemRefType = castOp.getOperand().getType().cast<MemRefType>(); - MemRefType targetMemRefType = - castOp.getResult().getType().cast<MemRefType>(); + MemRefType targetMemRefType = castOp.getType(); // Only static shape casts supported atm. 
if (!sourceMemRefType.hasStaticShape() || @@ -1459,7 +1458,7 @@ LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto dstType = op.getResult().getType().cast<VectorType>(); + auto dstType = op.getType(); assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -67,7 +67,7 @@ ArrayAttr masks = m.mask_dim_sizes(); assert(masks.size() == 1); int64_t i = masks[0].cast<IntegerAttr>().getInt(); - int64_t u = m.getType().cast<VectorType>().getDimSize(0); + int64_t u = m.getType().getDimSize(0); if (i >= u) return MaskFormat::AllTrue; if (i <= 0) @@ -849,7 +849,7 @@ return Value(); // Get the nth dimension size starting from lowest dimension. auto getDimReverse = [](VectorType type, int64_t n) { - return type.getShape().take_back(n+1).front(); + return type.getShape().take_back(n + 1).front(); }; int64_t destinationRank = extractOp.getType().isa<VectorType>() @@ -1870,9 +1870,8 @@ auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); if (!dense) return failure(); - auto newAttr = DenseElementsAttr::get( - extractStridedSliceOp.getType().cast<VectorType>(), - dense.getSplatValue()); + auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), + dense.getSplatValue()); rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -999,8 +999,7 @@ return failure(); auto operandSourceVectorType = sourceShapeCastOp.source().getType().cast<VectorType>(); - auto operandResultVectorType = - sourceShapeCastOp.result().getType().cast<VectorType>(); + auto operandResultVectorType = 
sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. if (operandSourceVectorType != resultVectorType || @@ -1397,7 +1396,7 @@ LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto dstType = op.getResult().getType().cast<VectorType>(); + auto dstType = op.getType(); auto eltType = dstType.getElementType(); auto dimSizes = op.mask_dim_sizes(); int64_t rank = dimSizes.size(); diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -53,6 +53,11 @@ .Default([](auto *) { return llvm::None; }); } +// Return the C++ class name for this type (which may just be ::mlir::Type). +StringRef TypeConstraint::getCPPClassName() const { + return def->getValueAsString("cppClassName"); +} + Type::Type(const llvm::Record *record) : TypeConstraint(record) {} StringRef Type::getTypeDescription() const { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2137,11 +2137,18 @@ unsigned numVariadicRegions = op.getNumVariadicRegions(); addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); - // Add result size trait. + // Add result size traits. int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); + // For single result ops with a known specific type, generate a OneTypedResult + // trait. + if (numResults == 1 && numVariadicResults == 0) { + auto cppName = op.getResults().begin()->constraint.getCPPClassName(); + opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); + } + // Add successor size trait. unsigned numSuccessors = op.getNumSuccessors(); unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();