diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md --- a/mlir/docs/Tutorials/Toy/Ch-2.md +++ b/mlir/docs/Tutorials/Toy/Ch-2.md @@ -210,7 +210,9 @@ /// The ConstantOp takes no inputs. mlir::OpTrait::ZeroOperands, /// The ConstantOp returns a single result. - mlir::OpTrait::OneResult> { + mlir::OpTrait::OneResult, + /// The result of getType is `Type`. + mlir::OpTrait::OneTypedResult::Impl> { public: /// Inherit the constructors from the base Op class. diff --git a/mlir/examples/standalone/include/Standalone/StandaloneOps.h b/mlir/examples/standalone/include/Standalone/StandaloneOps.h --- a/mlir/examples/standalone/include/Standalone/StandaloneOps.h +++ b/mlir/examples/standalone/include/Standalone/StandaloneOps.h @@ -9,6 +9,7 @@ #ifndef STANDALONE_STANDALONEOPS_H #define STANDALONE_STANDALONEOPS_H +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h --- a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h +++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_ #define MLIR_DIALECT_AVX512_AVX512DIALECT_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ #define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/IR/OpBase.td
b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -175,8 +175,12 @@ // are considered as uncategorized constraints. // Subclass for constraints on a type. -class TypeConstraint : - Constraint; +class TypeConstraint : + Constraint { + // The name of the C++ Type class if known, or Type if not. + string cppClassName = cppClassNameParam; +} // Subclass for constraints on an attribute. class AttrConstraint : @@ -285,8 +289,9 @@ //===----------------------------------------------------------------------===// // A type, carries type constraints. -class Type : - TypeConstraint { +class Type : + TypeConstraint { string typeDescription = ""; string builderCall = ""; } @@ -299,8 +304,9 @@ } // A type of a specific dialect. -class DialectType : - Type { +class DialectType : + Type { Dialect dialect = d; } @@ -331,11 +337,13 @@ def AnyType : Type, "any type">; // None type -def NoneType : Type()">, "none type">, +def NoneType : Type()">, "none type", + "::mlir::NoneType">, BuildableType<"$_builder.getType<::mlir::NoneType>()">; // Any type from the given list -class AnyTypeOf allowedTypes, string description = ""> : Type< +class AnyTypeOf allowedTypes, string description = "", + string cppClassName = "::mlir::Type"> : Type< // Satisfy any of the allowed type's condition Or, !if(!eq(description, ""), @@ -345,7 +353,8 @@ // Integer types. // Any integer type irrespective of its width and signedness semantics. -def AnyInteger : Type()">, "integer">; +def AnyInteger : Type()">, "integer", + "::mlir::IntegerType">; // Any integer type (regardless of signedness semantics) of a specific width. class AnyI @@ -355,7 +364,8 @@ class AnyIntOfWidths widths> : AnyTypeOf), - StrJoinInt.result # "-bit integer">; + StrJoinInt.result # "-bit integer", + "::mlir::IntegerType">; def AnyI1 : AnyI<1>; def AnyI8 : AnyI<8>; @@ -365,12 +375,13 @@ // Any signless integer type irrespective of its width. 
def AnySignlessInteger : Type< - CPred<"$_self.isSignlessInteger()">, "signless integer">; + CPred<"$_self.isSignlessInteger()">, "signless integer", + "::mlir::IntegerType">; // Signless integer type of a specific width. class I : Type, - width # "-bit signless integer">, + width # "-bit signless integer", "::mlir::IntegerType">, BuildableType<"$_builder.getIntegerType(" # width # ")"> { int bitwidth = width; } @@ -392,7 +403,7 @@ // Signed integer type of a specific width. class SI : Type, - width # "-bit signed integer">, + width # "-bit signed integer", "::mlir::IntegerType">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> { int bitwidth = width; @@ -415,7 +426,7 @@ // Unsigned integer type of a specific width. class UI : Type, - width # "-bit unsigned integer">, + width # "-bit unsigned integer", "::mlir::IntegerType">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> { int bitwidth = width; @@ -432,18 +443,20 @@ def UI64 : UI<64>; // Index type. -def Index : Type()">, "index">, +def Index : Type()">, "index", + "::mlir::IndexType">, BuildableType<"$_builder.getIndexType()">; // Floating point types. // Any float type irrespective of its width. -def AnyFloat : Type()">, "floating-point">; +def AnyFloat : Type()">, "floating-point", + "::mlir::FloatType">; // Float type of a specific width. 
class F : Type, - width # "-bit float">, + width # "-bit float", "::mlir::FloatType">, BuildableType<"$_builder.getF" # width # "Type()"> { int bitwidth = width; } @@ -465,16 +478,17 @@ SubstLeaves<"$_self", "$_self.cast<::mlir::ComplexType>().getElementType()", type.predicate>]>, - "complex type with " # type.description # " elements"> { + "complex type with " # type.description # " elements", + "::mlir::ComplexType"> { Type elementType = type; } def AnyComplex : Type()">, - "complex-type">; + "complex-type", "::mlir::ComplexType">; class OpaqueType : Type, - description>, + description, "::mlir::OpaqueType">, BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), " "$_builder.getIdentifier(\"" # dialect # "\"), \"" # name # "\")">; @@ -483,17 +497,17 @@ // Any function type. def FunctionType : Type()">, - "function type">; + "function type", "::mlir::FunctionType">; // A container type is a type that has another type embedded within it. class ContainerType : + string descr, string cppClassName = "::mlir::Type"> : // First, check the container predicate. Then, substitute the extracted // element into the element type checker. Type(elementTypeCall), etype.predicate>]>, - descr # " of " # etype.description # " values"> { + descr # " of " # etype.description # " values", cppClassName> { // The type of elements in the container. Type elementType = etype; @@ -502,9 +516,11 @@ } class ShapedContainerType allowedTypes, - Pred containerPred, string descr> : + Pred containerPred, string descr, + string cppClassName = "::mlir::Type"> : ContainerType, containerPred, - "$_self.cast<::mlir::ShapedType>().getElementType()", descr>; + "$_self.cast<::mlir::ShapedType>().getElementType()", descr, + cppClassName>; // Whether a shaped type is ranked. def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; @@ -520,7 +536,8 @@ // Vector types. 
class VectorOf allowedTypes> : - ShapedContainerType; + ShapedContainerType; // Whether the number of elements of a vector is from the given // `allowedRanks` list @@ -534,7 +551,7 @@ // Any vector where the rank is from the given `allowedRanks` list class VectorOfRank allowedRanks> : Type< IsVectorOfRankPred, - " of ranks " # StrJoinInt.result>; + " of ranks " # StrJoinInt.result, "::mlir::VectorType">; // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list @@ -543,7 +560,8 @@ And<[VectorOf.predicate, VectorOfRank.predicate]>, VectorOf.description # - VectorOfRank.description>; + VectorOfRank.description, + "::mlir::VectorType">; // Whether the number of elements of a vector is from the given // `allowedLengths` list @@ -558,7 +576,8 @@ // `allowedLengths` list class VectorOfLength allowedLengths> : Type< IsVectorOfLengthPred, - " of length " # StrJoinInt.result>; + " of length " # StrJoinInt.result, + "::mlir::VectorType">; // Any vector where the number of elements is from the given @@ -569,30 +588,34 @@ And<[VectorOf.predicate, VectorOfLength.predicate]>, VectorOf.description # - VectorOfLength.description>; + VectorOfLength.description, + "::mlir::VectorType">; def AnyVector : VectorOf<[AnyType]>; // Shaped types. -def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; +def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", + "::mlir::ShapedType">; // Tensor types. // Any tensor type whose element type is from the given `allowedTypes` list class TensorOf allowedTypes> : - ShapedContainerType; + ShapedContainerType; def AnyTensor : TensorOf<[AnyType]>; def AnyRankedTensor : ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>, - "ranked tensor">; + "ranked tensor", "::mlir::TensorType">; // TODO: Have an easy way to add another constraint to a type. 
class StaticShapeTensorOf allowedTypes> : Type.predicate, HasStaticShapePred]>, - "statically shaped " # TensorOf.description>; + "statically shaped " # TensorOf.description, + "::mlir::TensorType">; def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; @@ -612,7 +635,7 @@ class TensorRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, StrJoin.result # " " # - TensorOf.description>; + TensorOf.description, "::mlir::TensorType">; class 0DTensorOf allowedTypes> : TensorRankOf; class 1DTensorOf allowedTypes> : TensorRankOf; @@ -623,12 +646,14 @@ // Unranked Memref type def AnyUnrankedMemRef : ShapedContainerType<[AnyType], - IsUnrankedMemRefTypePred, "unranked.memref">; + IsUnrankedMemRefTypePred, "unranked.memref", + "::mlir::UnrankedMemRefType">; // Memref type. // Memrefs are blocks of data with fixed type and rank. class MemRefOf allowedTypes> : - ShapedContainerType; + ShapedContainerType; def AnyMemRef : MemRefOf<[AnyType]>; @@ -679,7 +704,7 @@ MemRefOf.description>; // This represents a generic tuple without any constraints on element type. -def AnyTuple : Type; +def AnyTuple : Type; // A container type that has other types embedded in it, but (unlike // ContainerType) can hold elements with a mix of types. Requires a call that @@ -2414,9 +2439,7 @@ // the given C++ base class. class TypeDef - : DialectType> { - // The name of the C++ Type class. - string cppClassName = name # "Type"; + : DialectType, /*descr*/"", name # "Type"> { // The name of the C++ base class to use for this Type. string cppBaseClassName = baseCppClass; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -28,10 +28,6 @@ class Builder; class OpBuilder; -namespace OpTrait { -template class OneResult; -} - /// This class represents success/failure for operation parsing. It is
It is /// essentially a simple wrapper class around LogicalResult that allows for /// explicit conversion to bool. This allows for the parser to chain together @@ -188,7 +184,8 @@ void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); } /// Set the dialect attributes for this operation, and preserve all dependent. - template void setDialectAttrs(DialectAttrs &&attrs) { + template + void setDialectAttrs(DialectAttrs &&attrs) { state->setDialectAttrs(std::forward(attrs)); } @@ -424,7 +421,8 @@ /// /// class FooOp : public Op::Impl> { /// -template class NOperands { +template +class NOperands { public: static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); @@ -443,7 +441,8 @@ /// /// class FooOp : public Op::Impl> { /// -template class AtLeastNOperands { +template +class AtLeastNOperands { public: template class Impl : public detail::MultiOperandTraitBase class NRegions { +template +class NRegions { public: static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2"); @@ -533,7 +533,8 @@ /// This class provides APIs for ops that are known to have at least a specified /// number of regions. -template class AtLeastNRegions { +template +class AtLeastNRegions { public: template class Impl : public detail::MultiRegionTraitBase void replaceAllUsesWith(ValuesT &&values) { + template + void replaceAllUsesWith(ValuesT &&values) { this->getOperation()->replaceAllUsesWith(std::forward(values)); } @@ -610,20 +612,19 @@ } // end namespace detail /// This class provides return value APIs for ops that are known to have a -/// single result. +/// single result. ResultType is the concrete type returned by getType(). template class OneResult : public TraitBase { public: Value getResult() { return this->getOperation()->getResult(0); } - Type getType() { return getResult().getType(); } /// If the operation returns a single value, then the Op can be implicitly /// converted to an Value. This yields the value of the only result. 
operator Value() { return getResult(); } - /// Replace all uses of 'this' value with the new value, updating anything in - /// the IR that uses 'this' to use the other value instead. When this returns - /// there are zero uses of 'this'. + /// Replace all uses of 'this' value with the new value, updating anything + /// in the IR that uses 'this' to use the other value instead. When this + /// returns there are zero uses of 'this'. void replaceAllUsesWith(Value newValue) { getResult().replaceAllUsesWith(newValue); } @@ -638,12 +639,33 @@ } }; +/// This trait provides getType() for ops that are known to have a single +/// result, returning the concrete `ResultType` (which may just be `Type`). +/// This should be used in conjunction with OneResult, and is conventionally +/// listed after OneResult in the trait list. +template +class OneTypedResult { +public: + /// This class provides return value APIs for ops that are known to have a + /// single result. ResultType is the concrete type returned by getType(). + template + class Impl + : public TraitBase::Impl> { + public: + ResultType getType() { + auto resultTy = this->getOperation()->getResult(0).getType(); + return resultTy.template cast(); + } + }; +}; + /// This class provides the API for ops that are known to have a specified /// number of results. This is used as a trait like this: /// /// class FooOp : public Op::Impl> { /// -template class NResults { +template +class NResults { public: static_assert(N > 1, "use ZeroResult/OneResult for N < 2"); @@ -662,7 +684,8 @@ /// /// class FooOp : public Op::Impl> { /// -template class AtLeastNResults { +template +class AtLeastNResults { public: template class Impl : public detail::MultiResultTraitBase().fold(std::declval>(), std::declval &>())); - template using detect_has_fold = llvm::is_detected; + template + using detect_has_fold = llvm::is_detected; /// Trait to check if T provides a 'print' method.
template using has_print = diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -47,6 +47,9 @@ // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise. Optional getBuilderCall() const; + + // Return the C++ class name for this type (which may just be ::mlir::Type). + StringRef getCPPClassName() const; }; // Wrapper class with helper methods for accessing Types defined in TableGen. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -612,7 +612,7 @@ LogicalResult matchAndRewrite(vector::CreateMaskOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto dstType = op->getResult(0).getType().cast(); + auto dstType = op.getType(); int64_t rank = dstType.getRank(); if (rank == 1) { rewriter.replaceOp( @@ -1091,8 +1091,7 @@ auto loc = castOp->getLoc(); MemRefType sourceMemRefType = castOp.getOperand().getType().cast(); - MemRefType targetMemRefType = - castOp.getResult().getType().cast(); + MemRefType targetMemRefType = castOp.getType(); // Only static shape casts supported atm. 
if (!sourceMemRefType.hasStaticShape() || @@ -1461,7 +1460,7 @@ LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto dstType = op.getResult().getType().cast(); + auto dstType = op.getType(); assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -67,7 +67,7 @@ ArrayAttr masks = m.mask_dim_sizes(); assert(masks.size() == 1); int64_t i = masks[0].cast().getInt(); - int64_t u = m.getType().cast().getDimSize(0); + int64_t u = m.getType().getDimSize(0); if (i >= u) return MaskFormat::AllTrue; if (i <= 0) @@ -849,7 +849,7 @@ return Value(); // Get the nth dimension size starting from lowest dimension. auto getDimReverse = [](VectorType type, int64_t n) { - return type.getShape().take_back(n+1).front(); + return type.getShape().take_back(n + 1).front(); }; int64_t destinationRank = extractOp.getType().isa() @@ -1839,9 +1839,8 @@ auto dense = constantOp.value().dyn_cast(); if (!dense) return failure(); - auto newAttr = DenseElementsAttr::get( - extractStridedSliceOp.getType().cast(), - dense.getSplatValue()); + auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), + dense.getSplatValue()); rewriter.replaceOpWithNewOp(extractStridedSliceOp, newAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -999,8 +999,7 @@ return failure(); auto operandSourceVectorType = sourceShapeCastOp.source().getType().cast(); - auto operandResultVectorType = - sourceShapeCastOp.result().getType().cast(); + auto operandResultVectorType = sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. 
if (operandSourceVectorType != resultVectorType || @@ -1397,7 +1396,7 @@ LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto dstType = op.getResult().getType().cast(); + auto dstType = op.getType(); auto eltType = dstType.getElementType(); auto dimSizes = op.mask_dim_sizes(); int64_t rank = dimSizes.size(); diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -53,6 +53,11 @@ .Default([](auto *) { return llvm::None; }); } +// Return the C++ class name for this type (which may just be ::mlir::Type). +StringRef TypeConstraint::getCPPClassName() const { + return def->getValueAsString("cppClassName"); +} + Type::Type(const llvm::Record *record) : TypeConstraint(record) {} StringRef Type::getTypeDescription() const { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2137,11 +2137,18 @@ unsigned numVariadicRegions = op.getNumVariadicRegions(); addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); - // Add result size trait. + // Add result size traits. int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); + // For single result ops, generate a OneTypedResult trait so that getType() + // returns the result's C++ class (which may just be ::mlir::Type). + if (numResults == 1 && numVariadicResults == 0) { + auto cppName = op.getResults().begin()->constraint.getCPPClassName(); + opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); + } + // Add successor size trait. unsigned numSuccessors = op.getNumSuccessors(); unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();