diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -93,7 +93,8 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+    [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Creates a constant shape or extent tensor";
   let description = [{
     Creates a constant shape or extent tensor. The individual extents are given
@@ -103,7 +104,7 @@
     ```mlir
     %0 = shape.const_shape [] : !shape.shape
     %1 = shape.const_shape [1, 2, 3] : !shape.shape
-    %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+    %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
     ```
   }];
   let arguments = (ins IndexElementsAttr:$shape);
@@ -114,6 +115,11 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    // InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -740,6 +740,58 @@
   patterns.add<TensorCastConstShape>(context);
 }
 
+/// Infers the result type of `shape.const_shape` from its `shape` attribute:
+/// a rank-1 tensor of `index` with one element per extent. Emits an optional
+/// error at `location` and fails when the attribute cannot be retrieved.
+LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Builder b(context);
+  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
+  if (!shape)
+    return emitOptionalError(location, "missing shape attribute");
+  inferredReturnTypes.assign({RankedTensorType::get(
+      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
+  return success();
+}
+
+/// Returns true when the single type in `l` and the single type in `r` may
+/// stand in for each other as the result type of `shape.const_shape`: they are
+/// equal, either is a shape type, or both are tensors with the same element
+/// type where one is unranked or both are rank-1 with matching (or dynamic)
+/// extents. Returns false when either range does not hold exactly one type.
+bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
+                                                        TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+  if (l == r)
+    return true;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    // Shape type is compatible with all other valid return types.
+    return true;
+
+  auto lhsTensorType = lhs.cast<TensorType>();
+  auto rhsTensorType = rhs.cast<TensorType>();
+  if (lhsTensorType.getElementType() != rhsTensorType.getElementType())
+    return false;
+  if (!lhsTensorType.hasRank() || !rhsTensorType.hasRank())
+    // If either is unranked, then it is compatible.
+    return true;
+  if (lhsTensorType.getRank() != 1 || rhsTensorType.getRank() != 1)
+    return false;
+
+  if (ShapedType::isDynamic(lhsTensorType.getDimSize(0)) ||
+      ShapedType::isDynamic(rhsTensorType.getDimSize(0)))
+    return true;
+
+  return lhsTensorType.getDimSize(0) == rhsTensorType.getDimSize(0);
+}
+
 //===----------------------------------------------------------------------===//
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -254,3 +254,13 @@
   %0 = shape.cstr_broadcastable %arg : !shape.shape
   return %0 : !shape.witness
 }
+
+// -----
+
+// Test that type inference flags the wrong return type.
+
+func @const_shape() {
+  // expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
+  %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
+  return
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -36,6 +36,7 @@
 func @const_shape() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
   return
 }
 