diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -509,11 +509,14 @@
 //===----------------------------------------------------------------------===//
 
 // TODO: Move the code below and witnesses to a different file.
-def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
+def Shape_AnyOp : Shape_Op<"any", [Commutative,
+                                   NoSideEffect,
+                                   SameOperandsAndResultType]> {
   let summary = "Return any combination of the input shapes";
   let description = [{
-    This operation takes multiple input shapes and returns some combination of
-    their dimensions. This can be best seen with examples below.
+    This operation takes multiple input shapes or extent tensors and returns
+    some combination of their dimensions. This can be best seen with examples
+    below.
 
     The result is undefined, but still side-effect free, in cases where the
     inputs have differing ranks or differ in extents of shared dimensions.
@@ -525,11 +528,10 @@
     ```
   }];
 
-  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
-  let results = (outs Shape_ShapeType:$result);
-
-  let assemblyFormat = "$inputs attr-dict";
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let assemblyFormat = "$inputs `:` type($result) attr-dict";
   let hasFolder = 1;
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -165,11 +165,12 @@
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
-func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-    -> !shape.shape {
+func @any_of_three(%a : tensor<?xindex>,
+                   %b : tensor<?xindex>,
+                   %c : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a, %b, %c
-  return %result : !shape.shape
+  %result = shape.any %a, %b, %c : tensor<?xindex>
+  return %result : tensor<?xindex>
 }
 
 // -----
@@ -177,9 +178,9 @@
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_one
 // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
-func @any_of_one(%a : !shape.shape) -> !shape.shape {
+func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a
-  return %result : !shape.shape
+  %result = shape.any %a : tensor<?xindex>
+  return %result : tensor<?xindex>
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -364,14 +364,25 @@
 
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
-func @f(%arg0 : !shape.shape) -> !shape.shape {
+func @f(%arg : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = shape.any %0, %arg0
+  %1 = shape.any %0, %arg : !shape.shape
   return %1 : !shape.shape
 }
 
+// -----
+
+// any can be replaced with a constant input if it has one (extent tensor case).
+// CHECK-LABEL: func @f
+func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
+  // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
+  %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
+  %1 = shape.any %0, %arg : tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
 
 // -----
 
@@ -380,7 +391,7 @@
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.any
   // CHECK-NEXT: return %[[CS]]
-  %1 = shape.any %arg0, %arg1
+  %1 = shape.any %arg0, %arg1 : !shape.shape
   return %1 : !shape.shape
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -1,4 +1,3 @@
-// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
 // Verify the printed output can be parsed.
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 // Verify the generic form can be parsed.
@@ -99,7 +98,7 @@
   %w3 = shape.const_witness false
   %w4 = shape.assuming_all %w0, %w1, %w2, %w3
   shape.assuming %w4 -> !shape.shape {
-    %2 = shape.any %0, %1
+    %2 = shape.any %0, %1 : !shape.shape
     shape.assuming_yield %2 : !shape.shape
   }
   return
@@ -173,3 +172,13 @@
   %result = shape.get_extent %arg, %c0 : tensor<?xindex>
   return %result : !shape.size
 }
+
+func @any() {
+  %0 = shape.const_shape [1, 2, 3] : !shape.shape
+  %1 = shape.const_shape [4, 5, 6] : !shape.shape
+  %2 = shape.any %0, %1 : !shape.shape
+  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %5 = shape.any %3, %4 : tensor<?xindex>
+  return
+}