diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -28,7 +28,9 @@
 class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+def Shape_AddOp
+    : Shape_Op<"add", [Commutative, NoSideEffect,
+                       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Addition of sizes and indices";
   let description = [{
     Adds two sizes or indices. If either operand is an error it will be
@@ -47,6 +49,11 @@
   }];
 
   let verifier = [{ return verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@@ -77,6 +84,8 @@
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let builders = [OpBuilder<(ins "Value":$shape)>];
+
   let assemblyFormat = [{
     $shapes attr-dict `:` type($shapes) `->` type($result)
   }];
@@ -139,7 +148,9 @@
   let hasFolder = 1;
 }
 
-def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+def Shape_DivOp
+    : Shape_Op<"div", [NoSideEffect,
+                       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Division of sizes and indices";
   let description = [{
     Divides two sizes or indices. If either operand is an error it will be
@@ -167,10 +178,15 @@
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
-                                            InferTypeOpInterface]> {
+def Shape_ShapeEqOp
+    : Shape_Op<"shape_eq", [NoSideEffect, Commutative, InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
     Takes one or more shape or extent tensor operands and determines whether
@@ -284,7 +300,9 @@
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
 }
 
-def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
+def Shape_RankOp
+    : Shape_Op<"rank", [NoSideEffect,
+                        DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the rank of a shape";
   let description = [{
     Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -298,6 +316,11 @@
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -318,7 +341,10 @@
   let hasFolder = 1;
 }
 
-def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
+def Shape_GetExtentOp
+    : Shape_Op<"get_extent",
+               [NoSideEffect,
+                DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
     Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -338,6 +364,8 @@
   let extraClassDeclaration = [{
     /// Get the `dim` value as integer if it is constant.
     Optional<int64_t> getConstantDim();
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
@@ -363,7 +391,9 @@
   let hasCanonicalizer = 1;
 }
 
-def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
+def Shape_JoinOp
+    : Shape_Op<"join", [Commutative,
+                        DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the least general shape.shape of its operands";
   let description = [{
     An operation that computes the least general shape of input operands.
@@ -399,9 +429,16 @@
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
     type($arg0) `,` type($arg1) `->` type($result)
   }];
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+def Shape_MaxOp
+    : Shape_Op<"max", [Commutative, NoSideEffect,
+                       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise maximum";
   let description = [{
     Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -418,9 +455,16 @@
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+def Shape_MinOp
+    : Shape_Op<"min", [Commutative, NoSideEffect,
+                       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Elementwise minimum";
   let description = [{
     Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -437,9 +481,16 @@
   }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
+def Shape_MulOp
+    : Shape_Op<"mul", [Commutative, NoSideEffect,
+                       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Multiplication of sizes and indices";
   let description = [{
     Multiplies two sizes or indices. If either operand is an error it will be
@@ -459,9 +510,17 @@
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
-def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
+def Shape_NumElementsOp
+    : Shape_Op<"num_elements",
+               [NoSideEffect,
+                DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the number of elements for a given shape";
   let description = [{
     Returns the number of elements for a given shape which is the product of its
@@ -474,12 +533,14 @@
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeOrIndexType:$result);
 
-  let builders = [OpBuilder<(ins "Value":$shape)>];
-
   let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
 
   let hasFolder = 1;
   let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ReduceOp : Shape_Op<"reduce",
@@ -529,7 +590,10 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
+def Shape_ShapeOfOp
+    : Shape_Op<"shape_of",
+               [NoSideEffect,
+                DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let description = [{
@@ -542,11 +606,14 @@
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
-  let builders = [OpBuilder<(ins "Value":$arg)>];
-
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
@@ -861,7 +928,8 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
+def Shape_CstrEqOp
+    : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
   let summary = "Determines if all input shapes are equal";
   let description = [{
     Given 1 or more input shapes, determine if all shapes are the exact same.
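Note on the TableGen changes above: `DeclareOpInterfaceMethods<InferTypeOpInterface>` makes ODS declare `inferReturnTypes` on each op, and the `isCompatibleReturnTypes` declarations relax the interface's verification from exact type equality to the compatibility rules implemented in Shape.cpp below (in particular, letting `index` stand in for `!shape.size`). Several of the C++ implementations that follow repeat the same size-or-index rule verbatim; a possible consolidation is sketched here. This is illustrative only, not part of the patch: the helper name `inferSizeOrIndex` is hypothetical, and the sketch assumes it lives in Shape.cpp so the file's existing includes and using-directives apply.

static LogicalResult inferSizeOrIndex(MLIRContext *context,
                                      ValueRange operands,
                                      SmallVectorImpl<Type> &results) {
  // A !shape.size operand may hold an error value that has to propagate, so
  // the result must be !shape.size if any operand is; otherwise a plain
  // index suffices.
  bool anySize = llvm::any_of(operands, [](Value operand) {
    return operand.getType().isa<SizeType>();
  });
  results.assign({anySize ? Type(SizeType::get(context))
                          : Type(IndexType::get(context))});
  return success();
}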
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -401,6 +402,29 @@
   result.addTypes(assumingTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::AddOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
@@ -920,6 +944,25 @@
   return IntegerAttr::get(indexTy, quotient);
 }
 
+LogicalResult mlir::shape::DivOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
@@ -1061,6 +1104,22 @@
   }
 }
 
+LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
@@ -1079,6 +1138,45 @@
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// JoinOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::shape::JoinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // The types of the two operands must be consistent.
+  if (operands[0].getType() != operands[1].getType())
+    return failure();
+
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({ShapeType::get(context)});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (!lhs || !rhs)
+    return false;
+
+  if (lhs.isa<ShapeType>() || lhs.isa<SizeType>())
+    return true;
+
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
@@ -1138,6 +1236,24 @@
   patterns.add<RankShapeOfCanonicalizationPattern>(context);
 }
 
+LogicalResult mlir::shape::RankOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
@@ -1156,14 +1272,23 @@
   return builder.getIndexAttr(product.getLimitedValue());
 }
 
-void NumElementsOp::build(OpBuilder &builder, OperationState &result,
-                          Value shape) {
-  if (shape.getType().isa<ShapedType>()) {
-    auto type = builder.getIndexType();
-    return build(builder, result, type, shape);
-  }
-  auto type = SizeType::get(builder.getContext());
-  return build(builder, result, type, shape);
+LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ShapeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
+                                                         TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1177,6 +1302,24 @@
   return nullptr;
 }
 
+LogicalResult mlir::shape::MaxOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // MinOp
 //===----------------------------------------------------------------------===//
@@ -1188,6 +1331,24 @@
   return nullptr;
 }
 
+LogicalResult mlir::shape::MinOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType() == operands[1].getType())
+    inferredReturnTypes.assign({operands[0].getType()});
+  else
+    inferredReturnTypes.assign({SizeType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
@@ -1204,6 +1365,24 @@
   return IntegerAttr::get(indexTy, folded);
 }
 
+LogicalResult mlir::shape::MulOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<SizeType>() ||
+      operands[1].getType().isa<SizeType>())
+    inferredReturnTypes.assign({SizeType::get(context)});
+  else
+    inferredReturnTypes.assign({IndexType::get(context)});
+  return success();
+}
+
+bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  // SizeType is compatible with IndexType.
+  return true;
+}
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
@@ -1216,18 +1395,6 @@
   return builder.getIndexTensorAttr(type.getShape());
 }
 
-void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
-    int64_t rank =
-        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
-    Type indexTy = builder.getIndexType();
-    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
-    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
-  }
-  Type shapeTy = builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, shapeTy, arg);
-}
-
 namespace {
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@@ -1282,6 +1449,41 @@
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
+LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (operands[0].getType().isa<ValueShapeType>())
+    inferredReturnTypes.assign({ShapeType::get(context)});
+  else {
+    auto shapedTy = operands[0].getType().cast<ShapedType>();
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = IndexType::get(context);
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    inferredReturnTypes.assign({extentTensorTy});
+  }
+  return success();
+}
+
+bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    // Shape type is compatible with all other valid return types.
+    return true;
+
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -97,6 +97,14 @@
 
 // -----
 
+func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
+  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
+  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
+  return
+}
+
+// -----
+
 func @rank(%arg : !shape.shape) {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %0 = shape.rank %arg : !shape.shape -> index
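Because the ops now implement InferTypeOpInterface, ODS also generates builders that infer the result type from the operands, which is what allows the hand-written `NumElementsOp::build` and `ShapeOfOp::build` above to be deleted without breaking callers. A minimal usage sketch, assuming this patch is applied and `using namespace mlir;` is in scope; the function name `buildNumElements` is hypothetical:

// Builds shape.shape_of and shape.num_elements without spelling out the
// result types; both are inferred through InferTypeOpInterface.
static Value buildNumElements(OpBuilder &b, Location loc, Value tensor) {
  // For a ranked tensor<1x2xf32> operand this infers tensor<2xindex>.
  Value shape = b.create<shape::ShapeOfOp>(loc, tensor);
  // The operand is an extent tensor rather than !shape.shape, so the result
  // type is inferred as plain `index`.
  return b.create<shape::NumElementsOp>(loc, shape);
}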