diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -110,6 +110,11 @@
 
 def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
 
+// Any type representing a shape or size/dim.
+def Shape_AnyShapeOrSizeType : AnyTypeOf<
+  [Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
+  "any shape or size">;
+
 def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
   let description = [{
     A witness is a structural device in the compiler to maintain ordering of
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -406,11 +406,11 @@
 
 def Shape_MeetOp : Shape_Op<"meet",
     [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>]> {
-  let summary = "Returns the least general shape.shape of its operands";
+  let summary = "Returns the least general shape or size of its operands";
   let description = [{
-    An operation that computes the least general shape of input operands.
+    An operation that computes the least general shape or dim of input operands.
     This effectively asserts that corresponding static dimensions are equal.
-    The behavior is to match each element of the `shape.shape` and propagate the
+    The behavior is to match each element of the shape/size and propagate the
     most restrictive information, returning an invalid shape if there are
     contradictory requirements.
     E.g., using pseudo code
@@ -433,9 +433,11 @@
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
-                   OptionalAttr<StrAttr>:$error);
-  let results = (outs Shape_ShapeOrSizeType:$result);
+  let arguments = (ins
+    Shape_AnyShapeOrSizeType:$arg0,
+    Shape_AnyShapeOrSizeType:$arg1,
+    OptionalAttr<StrAttr>:$error);
+  let results = (outs Shape_AnyShapeOrSizeType:$result);
 
   let assemblyFormat = [{
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1309,7 +1309,68 @@
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.assign({operands[0].getType()});
+  // Folds the operand types pairwise into a single result type:
+  //   * !shape.size with !shape.size or index -> !shape.size,
+  //   * index with index -> index,
+  //   * !shape.shape with !shape.shape or an extent tensor -> !shape.shape,
+  //   * extent tensor with extent tensor -> the dynamically sized one if
+  //     either is dynamic, otherwise the (equal-length) static one.
+  // Any other combination (e.g. mixing sizes and shapes) is rejected.
+  if (operands.empty())
+    return failure();
+
+  auto isShapeType = [](Type arg) {
+    if (arg.isa<ShapeType>())
+      return true;
+    return isExtentTensorType(arg);
+  };
+
+  ValueRange::type_range types = operands.getTypes();
+  Type acc = types.front();
+  for (auto t : drop_begin(types)) {
+    Type l = acc, r = t;
+    if (!l.isa<ShapeType, SizeType>())
+      std::swap(l, r);
+
+    // Handle sizes, propagate error type if present.
+    if (l.isa<SizeType>()) {
+      if (r.isa<SizeType, IndexType>())
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<IndexType>()) {
+      if (r.isa<IndexType>())
+        acc = r;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<ShapeType>()) {
+      // Handle shapes, propagate error type if present.
+      if (isShapeType(r))
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (isExtentTensorType(l)) {
+      // After the swap above `r` may still be an `index` (e.g. when the
+      // accumulated type is `index` and `t` is an extent tensor); reject it
+      // before casting it to a ranked tensor, which would otherwise assert.
+      if (!isExtentTensorType(r))
+        return emitOptionalError(location, "requires all sizes or shapes");
+      auto rank1 = l.cast<RankedTensorType>().getShape()[0];
+      auto rank2 = r.cast<RankedTensorType>().getShape()[0];
+      if (ShapedType::isDynamic(rank1))
+        acc = l;
+      else if (ShapedType::isDynamic(rank2))
+        acc = r;
+      else if (rank1 != rank2)
+        return emitOptionalError(location, "unequal shape cardinality");
+      else
+        acc = l;
+    } else {
+      // Neither a shape nor a size type: nothing sensible to infer.
+      return emitOptionalError(location, "requires all sizes or shapes");
+    }
+  }
+  inferredReturnTypes.assign({acc});
   return success();
 }
 
@@ -1322,11 +1383,13 @@
   Type lhs = l.front();
   Type rhs = r.front();
 
-  if (lhs != rhs)
-    return false;
+  if (!lhs.isa<ShapeType, SizeType>())
+    std::swap(lhs, rhs);
 
-  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
-    return true;
+  if (lhs.isa<SizeType>())
+    return rhs.isa<SizeType, IndexType>();
+  if (lhs.isa<ShapeType>())
+    return rhs.isa<ShapeType, TensorType>();
 
   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
     return true;
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -272,3 +272,28 @@
   %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
   return
 }
+
+// -----
+
+func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
+  // expected-error@+1 {{requires all sizes or shapes}}
+  %result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
+  return %result : index
+}
+
+// -----
+
+func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+  // expected-error@+1 {{unequal shape cardinality}}
+  %result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
+// -----
+
+func.func @invalid_meet(%arg0 : index, %arg1 : tensor<2xindex>) -> index {
+  // expected-error@+1 {{requires all sizes or shapes}}
+  %result = shape.meet %arg0, %arg1 : index, tensor<2xindex> -> index
+  return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -325,3 +325,9 @@
     !shape.size, !shape.size -> !shape.size
   return %2 : !shape.size
 }
+
+func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
+  %result = shape.meet %arg0, %arg1 : index, index -> index
+  return %result : index
+}
+