diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -28,7 +28,9 @@
 class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+def Shape_AddOp : Shape_Op<"add",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Addition of sizes and indices";
   let description = [{
     Adds two sizes or indices. If either operand is an error it will be
@@ -47,6 +49,11 @@
   }];
 
   let verifier = [{ return verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@@ -145,7 +152,8 @@
   let hasFolder = 1;
 }
 
-def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+def Shape_DivOp : Shape_Op<"div",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Division of sizes and indices";
   let description = [{
     Divides two sizes or indices. If either operand is an error it will be
@@ -173,10 +181,15 @@
 
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
-                                            InferTypeOpInterface]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq",
+    [NoSideEffect, Commutative, InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
     Takes one or more shape or extent tensor operands and determines whether
@@ -290,7 +303,8 @@
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
 }
 
-def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
+def Shape_RankOp : Shape_Op<"rank",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the rank of a shape";
   let description = [{
     Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -304,6 +318,11 @@
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -324,7 +343,8 @@
   let hasFolder = 1;
 }
 
-def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
+def Shape_GetExtentOp : Shape_Op<"get_extent",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
     Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -344,6 +364,8 @@
   let extraClassDeclaration = [{
     /// Get the `dim` value as integer if it is constant.
     Optional<int64_t> getConstantDim();
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
@@ -369,7 +391,8 @@
   let hasCanonicalizer = 1;
 }
 
-def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
+def Shape_JoinOp : Shape_Op<"join",
+    [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the least general shape.shape of its operands";
   let description = [{
     An operation that computes the least general shape of input operands.
@@ -405,9 +428,16 @@
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
       type($arg0) `,` type($arg1) `->` type($result)
   }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+def Shape_MaxOp : Shape_Op<"max",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise maximum";
   let description = [{
     Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -424,9 +454,16 @@
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+def Shape_MinOp : Shape_Op<"min",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise minimum";
   let description = [{
     Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -443,9 +480,16 @@
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
+def Shape_MulOp : Shape_Op<"mul",
+    [Commutative, NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Multiplication of sizes and indices";
   let description = [{
     Multiplies two sizes or indices. If either operand is an error it will be
@@ -465,9 +509,15 @@
 
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
+def Shape_NumElementsOp : Shape_Op<"num_elements",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the number of elements for a given shape";
   let description = [{
     Returns the number of elements for a given shape which is the product of its
@@ -480,12 +530,15 @@
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeOrIndexType:$result);
 
-  let builders = [OpBuilder<(ins "Value":$shape)>];
-
   let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
 
   let hasFolder = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ReduceOp : Shape_Op<"reduce",
@@ -535,7 +588,8 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let description = [{
@@ -548,11 +602,14 @@
 
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
-  let builders = [OpBuilder<(ins "Value":$arg)>];
-
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -34,7 +34,9 @@
       The method takes an optional location which, if set, will be used to
       report errors on. The operands and attributes correspond to those with
       which an Operation would be created (e.g., as used in Operation::create)
-      and the regions of the op.
+      and the regions of the op. Be aware that this method is supposed to be
+      called with valid arguments, i.e., operands that have been verified;
+      calling it with invalid arguments may result in undefined behavior.
       }],
       /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"inferReturnTypes",
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -404,6 +404,30 @@
   result.addTypes(assumingTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::AddOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
@@ -955,6 +979,26 @@
   return IntegerAttr::get(indexTy, quotient);
 }
 
+LogicalResult mlir::shape::DivOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
@@ -1096,6 +1140,23 @@
   }
 }
 
+LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
@@ -1114,6 +1175,29 @@
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// JoinOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::JoinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({ShapeType::get(context)});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // inferReturnTypes only produces `!shape.shape` or `!shape.size`, so the
+  // declared result type has to match the inferred one exactly.
+  return l.front() == r.front();
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
@@ -1173,6 +1257,25 @@
   patterns.add<RankShapeOfCanonicalizationPattern>(context);
 }
 
+LogicalResult mlir::shape::RankOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
@@ -1191,14 +1294,24 @@
   return builder.getIndexAttr(product.getLimitedValue());
 }
 
-void NumElementsOp::build(OpBuilder &builder, OperationState &result,
-                          Value shape) {
-  if (shape.getType().isa<ShapedType>()) {
-    auto type = builder.getIndexType();
-    return build(builder, result, type, shape);
-  }
-  auto type = SizeType::get(builder.getContext());
-  return build(builder, result, type, shape);
+LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
+                                                         TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1212,6 +1325,27 @@
   return nullptr;
 }
 
+LogicalResult mlir::shape::MaxOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+    return true;
+  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // MinOp
 //===----------------------------------------------------------------------===//
@@ -1223,6 +1357,27 @@
   return nullptr;
 }
 
+LogicalResult mlir::shape::MinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+    return true;
+  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
@@ -1239,6 +1394,26 @@
   return IntegerAttr::get(indexTy, folded);
 }
 
+LogicalResult mlir::shape::MulOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return l.front().isa<SizeType, IndexType>() &&
+         r.front().isa<SizeType, IndexType>();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
@@ -1251,18 +1426,6 @@
   return builder.getIndexTensorAttr(type.getShape());
 }
 
-void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
-    int64_t rank =
-        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
-    Type indexTy = builder.getIndexType();
-    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
-    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
-  }
-  Type shapeTy = builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, shapeTy, arg);
-}
-
 namespace {
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@@ -1317,6 +1480,44 @@
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
+LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ValueShapeType>())
+    inferredReturnTypes.assign({ShapeType::get(context)});
+  else {
+    auto shapedTy = operands[0].getType().cast<ShapedType>();
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = IndexType::get(context);
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    inferredReturnTypes.assign({extentTensorTy});
+  }
+  return success();
+}
+
+bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
+    return false;
+
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    // Shape type is compatible with all other valid return types.
+    return true;
+
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -97,6 +97,14 @@
 
 // -----
 
+func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
+  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
+  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
+  return
+}
+
+// -----
+
 func @rank(%arg : !shape.shape) {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %0 = shape.rank %arg : !shape.shape -> index