diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -903,7 +903,7 @@ Pure, SameVariadicResultSize, ViewLikeOpInterface, - InferTypeOpInterfaceAdaptor]> { + InferTypeOpAdaptor]> { let summary = "Extracts a buffer base with offset and strides"; let description = [{ Extracts a base buffer, offset and strides. This op allows additional layers diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -713,9 +713,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, - NoRegionArguments]> { + InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, + RecursiveMemoryEffects, NoRegionArguments]> { let summary = "if-then-else operation"; let description = [{ The `scf.if` operation represents an if-then-else construct for diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -32,8 +32,7 @@ Op; def Shape_AddOp : Shape_Op<"add", - [Commutative, Pure, - DeclareOpInterfaceMethods]> { + [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Addition of sizes and indices"; let description = [{ Adds two sizes or indices. 
If either operand is an error it will be @@ -51,12 +50,6 @@ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; - let hasFolder = 1; let hasVerifier = 1; } @@ -109,7 +102,7 @@ } def Shape_ConstShapeOp : Shape_Op<"const_shape", - [ConstantLike, Pure, DeclareOpInterfaceMethods]> { + [ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Creates a constant shape or extent tensor"; let description = [{ Creates a constant shape or extent tensor. The individual extents are given @@ -128,11 +121,6 @@ let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasCanonicalizer = 1; - - let extraClassDeclaration = [{ - // InferTypeOpInterface: - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_ConstSizeOp : Shape_Op<"const_size", [ @@ -158,8 +146,7 @@ let hasFolder = 1; } -def Shape_DivOp : Shape_Op<"div", [Pure, - DeclareOpInterfaceMethods]> { +def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Division of sizes and indices"; let description = [{ Divides two sizes or indices. If either operand is an error it will be @@ -187,12 +174,6 @@ let hasFolder = 1; let hasVerifier = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> { @@ -287,7 +268,7 @@ } def Shape_RankOp : Shape_Op<"rank", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Gets the rank of a shape"; let description = [{ Returns the rank of the shape or extent tensor, i.e. the number of extents. 
@@ -301,12 +282,6 @@ let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [ @@ -330,7 +305,7 @@ } def Shape_DimOp : Shape_Op<"dim", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Gets the specified extent from the shape of a shaped input"; let description = [{ Gets the extent indexed by `dim` from the shape of the `value` operand. If @@ -354,17 +329,13 @@ let extraClassDeclaration = [{ /// Get the `index` value as integer if it is constant. std::optional getConstantIndex(); - - /// Returns when two result types are compatible for this op; method used - /// by InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; let hasFolder = 1; } def Shape_GetExtentOp : Shape_Op<"get_extent", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Gets the specified extent from a shape or extent tensor"; let description = [{ Gets the extent indexed by `dim` from the `shape` operand. If the shape is @@ -384,9 +355,6 @@ let extraClassDeclaration = [{ /// Get the `dim` value as integer if it is constant. std::optional getConstantDim(); - /// Returns when two result types are compatible for this op; method used - /// by InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; let hasFolder = 1; @@ -413,8 +381,7 @@ } def Shape_MaxOp : Shape_Op<"max", - [Commutative, Pure, - DeclareOpInterfaceMethods]> { + [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Elementwise maximum"; let description = [{ Computes the elementwise maximum of two sizes or shapes with equal ranks. 
@@ -431,16 +398,10 @@ }]; let hasFolder = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_MeetOp : Shape_Op<"meet", - [Commutative, DeclareOpInterfaceMethods]> { + [Commutative, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Returns the least general shape or size of its operands"; let description = [{ An operation that computes the least general shape or dim of input operands. @@ -478,17 +439,10 @@ $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:` type($arg0) `,` type($arg1) `->` type($result) }]; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_MinOp : Shape_Op<"min", - [Commutative, Pure, - DeclareOpInterfaceMethods]> { + [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Elementwise minimum"; let description = [{ Computes the elementwise minimum of two sizes or shapes with equal ranks. @@ -505,17 +459,10 @@ }]; let hasFolder = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_MulOp : Shape_Op<"mul", - [Commutative, Pure, - DeclareOpInterfaceMethods]> { + [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Multiplication of sizes and indices"; let description = [{ Multiplies two sizes or indices. 
If either operand is an error it will be @@ -535,16 +482,10 @@ let hasFolder = 1; let hasVerifier = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_NumElementsOp : Shape_Op<"num_elements", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Returns the number of elements for a given shape"; let description = [{ Returns the number of elements for a given shape which is the product of @@ -561,11 +502,6 @@ let hasFolder = 1; let hasVerifier = 1; - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_ReduceOp : Shape_Op<"reduce", @@ -616,7 +552,7 @@ } def Shape_ShapeOfOp : Shape_Op<"shape_of", - [Pure, DeclareOpInterfaceMethods]> { + [Pure, InferTypeOpAdaptorWithIsCompatible]> { let summary = "Returns shape of a value or shaped type operand"; let description = [{ @@ -632,12 +568,6 @@ let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; } def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -463,7 +463,7 @@ TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"second operand v2 and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>, - DeclareOpInterfaceMethods]>, + InferTypeOpAdaptor]>, Arguments<(ins AnyVectorOfAnyRank:$v1, 
AnyVectorOfAnyRank:$v2, I64ArrayAttr:$mask)>, Results<(outs AnyVector:$vector)> { @@ -572,7 +572,7 @@ Vector_Op<"extract", [Pure, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - DeclareOpInterfaceMethods]>, + InferTypeOpAdaptorWithIsCompatible]>, Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, Results<(outs AnyType)> { let summary = "extract operation"; @@ -598,7 +598,6 @@ VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)"; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -259,8 +259,8 @@ namespace OpTrait { template -class InferTypeOpInterfaceAdaptor - : public TraitBase {}; +class InferTypeOpAdaptor : public TraitBase { +}; /// Tensor type inference trait that constructs a tensor from the inferred /// shape and elemental types. diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -186,35 +186,42 @@ // Convenient trait to define a wrapper to inferReturnTypes that passes in the // Op Adaptor directly -def InferTypeOpInterfaceAdaptor : TraitList< +class InferTypeOpAdaptorBase : TraitList< [ // Op implements infer type op interface. 
DeclareOpInterfaceMethods, NativeOpTrait< - /*name=*/"InferTypeOpInterfaceAdaptor", + /*name=*/"InferTypeOpAdaptor", /*traits=*/[], /*extraOpDeclaration=*/[{ - static LogicalResult - inferReturnTypesAdaptor(MLIRContext *context, - std::optional location, + static ::mlir::LogicalResult + inferReturnTypes(::mlir::MLIRContext *context, + std::optional<::mlir::Location> location, Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes); - }], + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes); + }] # additionalDecls, /*extraOpDefinition=*/[{ - LogicalResult - $cppClass::inferReturnTypes(MLIRContext *context, - std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { + ::mlir::LogicalResult + $cppClass::inferReturnTypes(::mlir::MLIRContext *context, + std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { $cppClass::Adaptor adaptor(operands, attributes, properties, regions); - return $cppClass::inferReturnTypesAdaptor(context, + return $cppClass::inferReturnTypes(context, location, adaptor, inferredReturnTypes); } }] > ]>; +def InferTypeOpAdaptor : InferTypeOpAdaptorBase; +def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase< + [{ + static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r); + }] +>; + // Convenience class grouping together type and shaped type op interfaces for // ops that have tensor return types. 
class InferTensorTypeBase overridenMethods = []> : TraitList< @@ -231,13 +238,13 @@ /*traits=*/[], /*extraOpDeclaration=*/[{}], /*extraOpDefinition=*/[{ - LogicalResult - $cppClass::inferReturnTypes(MLIRContext *context, - std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - SmallVector retComponents; + ::mlir::LogicalResult + $cppClass::inferReturnTypes(::mlir::MLIRContext *context, + std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents; if (failed($cppClass::inferReturnTypeComponents(context, location, operands, attributes, properties, regions, retComponents))) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1354,7 +1354,7 @@ /// The number and type of the results are inferred from the /// shape of the source. 
-LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor( +LogicalResult ExtractStridedMetadataOp::inferReturnTypes( MLIRContext *context, std::optional location, ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1841,12 +1841,11 @@ LogicalResult IfOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, - ValueRange operands, DictionaryAttr attrs, - OpaqueProperties properties, RegionRange regions, + IfOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - if (regions.empty()) + if (adaptor.getRegions().empty()) return failure(); - Region *r = regions.front(); + Region *r = &adaptor.getThenRegion(); if (r->empty()) return failure(); Block &b = r->front(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -394,11 +394,10 @@ //===----------------------------------------------------------------------===// LogicalResult mlir::shape::AddOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (llvm::isa(operands[0].getType()) || - llvm::isa(operands[1].getType())) + MLIRContext *context, std::optional location, + AddOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (llvm::isa(adaptor.getLhs().getType()) || + llvm::isa(adaptor.getRhs().getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -916,18 +915,17 @@ } LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, 
RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { + MLIRContext *context, std::optional location, + ConstShapeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { Builder b(context); - Properties *prop = properties.as(); + const Properties &prop = adaptor.getProperties(); DenseIntElementsAttr shape; // TODO: this is only exercised by the Python bindings codepath which does not // support properties - if (prop) - shape = prop->shape; + if (prop.shape) + shape = prop.shape; else - shape = attributes.getAs("shape"); + shape = adaptor.getAttributes().getAs("shape"); if (!shape) return emitOptionalError(location, "missing shape attribute"); inferredReturnTypes.assign({RankedTensorType::get( @@ -1104,11 +1102,9 @@ } LogicalResult mlir::shape::DimOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - DimOpAdaptor dimOp(operands); - inferredReturnTypes.assign({dimOp.getIndex().getType()}); + MLIRContext *context, std::optional location, + DimOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.assign({adaptor.getIndex().getType()}); return success(); } @@ -1141,11 +1137,10 @@ } LogicalResult mlir::shape::DivOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (llvm::isa(operands[0].getType()) || - llvm::isa(operands[1].getType())) + MLIRContext *context, std::optional location, + DivOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (llvm::isa(adaptor.getLhs().getType()) || + llvm::isa(adaptor.getRhs().getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1361,9 +1356,8 @@ LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( - MLIRContext *context,
std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { + MLIRContext *context, std::optional location, + GetExtentOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.assign({IndexType::get(context)}); return success(); } @@ -1399,10 +1393,9 @@ //===----------------------------------------------------------------------===// LogicalResult mlir::shape::MeetOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands.empty()) + MLIRContext *context, std::optional location, + MeetOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getOperands().empty()) return failure(); auto isShapeType = [](Type arg) { @@ -1411,7 +1404,7 @@ return isExtentTensorType(arg); }; - ValueRange::type_range types = operands.getTypes(); + ValueRange::type_range types = adaptor.getOperands().getTypes(); Type acc = types.front(); for (auto t : drop_begin(types)) { Type l = acc, r = t; @@ -1535,10 +1528,9 @@ } LogicalResult mlir::shape::RankOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (llvm::isa(operands[0].getType())) + MLIRContext *context, std::optional location, + RankOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (llvm::isa(adaptor.getShape().getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1571,10 +1563,10 @@ } LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, 
RegionRange regions, + MLIRContext *context, std::optional location, + NumElementsOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - if (llvm::isa(operands[0].getType())) + if (llvm::isa(adaptor.getShape().getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1603,11 +1595,10 @@ } LogicalResult mlir::shape::MaxOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType() == operands[1].getType()) - inferredReturnTypes.assign({operands[0].getType()}); + MLIRContext *context, std::optional location, + MaxOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getLhs().getType() == adaptor.getRhs().getType()) + inferredReturnTypes.assign({adaptor.getLhs().getType()}); else inferredReturnTypes.assign({SizeType::get(context)}); return success(); @@ -1635,11 +1626,10 @@ } LogicalResult mlir::shape::MinOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType() == operands[1].getType()) - inferredReturnTypes.assign({operands[0].getType()}); + MLIRContext *context, std::optional location, + MinOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getLhs().getType() == adaptor.getRhs().getType()) + inferredReturnTypes.assign({adaptor.getLhs().getType()}); else inferredReturnTypes.assign({SizeType::get(context)}); return success(); @@ -1672,11 +1662,10 @@ } LogicalResult mlir::shape::MulOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - 
if (llvm::isa(operands[0].getType()) || - llvm::isa(operands[1].getType())) + MLIRContext *context, std::optional location, + MulOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (llvm::isa(adaptor.getLhs().getType()) || + llvm::isa(adaptor.getRhs().getType())) inferredReturnTypes.assign({SizeType::get(context)}); else inferredReturnTypes.assign({IndexType::get(context)}); @@ -1759,13 +1748,12 @@ } LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (llvm::isa(operands[0].getType())) + MLIRContext *context, std::optional location, + ShapeOfOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (llvm::isa(adaptor.getArg().getType())) inferredReturnTypes.assign({ShapeType::get(context)}); else { - auto shapedTy = llvm::cast(operands[0].getType()); + auto shapedTy = llvm::cast(adaptor.getArg().getType()); int64_t rank = shapedTy.hasRank() ? 
shapedTy.getRank() : ShapedType::kDynamic; Type indexTy = IndexType::get(context); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1146,15 +1146,15 @@ LogicalResult ExtractOp::inferReturnTypes(MLIRContext *, std::optional, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange, + ExtractOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - ExtractOp::Adaptor op(operands, attributes, properties); - auto vectorType = llvm::cast(op.getVector().getType()); - if (static_cast(op.getPosition().size()) == vectorType.getRank()) { + auto vectorType = llvm::cast(adaptor.getVector().getType()); + if (static_cast(adaptor.getPosition().size()) == + vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { - auto n = std::min(op.getPosition().size(), vectorType.getRank()); + auto n = + std::min(adaptor.getPosition().size(), vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( vectorType.getShape().drop_front(n), vectorType.getElementType())); } @@ -2114,17 +2114,15 @@ LogicalResult ShuffleOp::inferReturnTypes(MLIRContext *, std::optional, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange, + ShuffleOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - ShuffleOp::Adaptor op(operands, attributes, properties); - auto v1Type = llvm::cast(op.getV1().getType()); + auto v1Type = llvm::cast(adaptor.getV1().getType()); auto v1Rank = v1Type.getRank(); // Construct resulting type: leading dimension matches mask // length, all trailing dimensions match the operands. SmallVector shape; shape.reserve(v1Rank); - shape.push_back(std::max(1, op.getMask().size())); + shape.push_back(std::max(1, adaptor.getMask().size())); // In the 0-D case there is no trailing shape to append. 
if (v1Rank > 0) llvm::append_range(shape, v1Type.getShape().drop_front()); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1385,6 +1385,19 @@ return success(); } +LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( + MLIRContext *, std::optional location, + OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getX().getType() != adaptor.getY().getType()) { + return emitOptionalError(location, "operand type mismatch ", + adaptor.getX().getType(), " vs ", + adaptor.getY().getType()); + } + inferredReturnTypes.assign({adaptor.getX().getType()}); + return success(); +} + // TODO: We should be able to only define either inferReturnType or // refineReturnType, currently only refineReturnType can be omitted. LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -761,6 +761,12 @@ let results = (outs AnyTensor); } +def OpWithInferTypeAdaptorInterfaceOp : TEST_Op<"op_with_infer_type_adaptor_if", [ + InferTypeOpAdaptor]> { + let arguments = (ins AnyTensor:$x, AnyTensor:$y); + let results = (outs AnyTensor); +} + def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [ DeclareOpInterfaceMethods]> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -485,6 +485,8 @@ // output would be in reverse order underneath `op` from which // the attributes and regions are used. 
invokeCreateWithInferredReturnType(op); + invokeCreateWithInferredReturnType( + op); invokeCreateWithInferredReturnType< OpWithShapedTypeInferTypeInterfaceOp>(op); };