diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -609,8 +609,9 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> { let summary = "Determines if 2 shapes can be successfully broadcasted"; let description = [{ - Given 2 input shapes, return a witness specifying if they are broadcastable. - This broadcastable follows the same logic as what shape.broadcast documents. + Given two input shapes or extent tensors, return a witness specifying if + they are broadcastable. Broadcastability follows the same rules as those + documented for shape.broadcast. "cstr" operations represent runtime assertions. @@ -621,10 +622,11 @@ ``` }]; - let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs); + let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, + Shape_ShapeOrExtentTensorType:$rhs); let results = (outs Shape_WitnessType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict"; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)"; let hasCanonicalizer = 1; let hasFolder = 1; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -431,7 +431,7 @@ // CHECK-NEXT: return %cs0 = shape.const_shape [3, 1] %cs1 = shape.const_shape [1, 5] - %0 = shape.cstr_broadcastable %cs0, %cs1 + %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return } @@ -447,7 +447,7 @@ // CHECK-NEXT: return %cs0 = shape.const_shape [1, 3] %cs1 = shape.const_shape [1, 5] - %0 = shape.cstr_broadcastable %cs0, %cs1 + %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return } @@ -461,7 +461,7 @@ // CHECK-NEXT: 
consume.witness // CHECK-NEXT: return %cs0 = shape.const_shape [1,3] - %0 = shape.cstr_broadcastable %arg0, %cs0 + %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return } @@ -473,7 +473,20 @@ // CHECK-NEXT: shape.const_witness true // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %0 = shape.cstr_broadcastable %arg0, %arg0 + %0 = shape.cstr_broadcastable %arg0, %arg0 : !shape.shape, !shape.shape + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- + +// Broadcastable canonicalization also works on extent tensors. +// CHECK-LABEL: func @broadcastable_on_extent_tensors +func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_broadcastable %arg, %arg : tensor<?xindex>, tensor<?xindex> "consume.witness"(%0) : (!shape.witness) -> () return } @@ -560,7 +573,7 @@ // CHECK-NEXT: return %0 = shape.const_shape [] %1 = shape.shape_of %arg0 : tensor - %2 = shape.cstr_broadcastable %0, %1 + %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape "consume.witness"(%2) : (!shape.witness) -> () return } @@ -577,7 +590,7 @@ // CHECK-NEXT: return %0 = shape.shape_of %arg0 : tensor %1 = shape.shape_of %arg1 : tensor - %2 = shape.cstr_broadcastable %0, %1 + %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape "consume.witness"(%2) : (!shape.witness) -> () return } @@ -592,7 +605,7 @@ // CHECK-NEXT: return %0 = shape.shape_of %arg1 : tensor %1 = shape.shape_of %arg0 : tensor<*xf32> - %2 = shape.cstr_broadcastable %0, %1 + %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape "consume.witness"(%2) : (!shape.witness) -> () return } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -86,7 +86,7 @@ func @test_constraints() { %0 = shape.const_shape [] %1 = 
shape.const_shape [1, 2, 3] - %w0 = shape.cstr_broadcastable %0, %1 + %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape %w1 = shape.cstr_eq %0, %1 %w2 = shape.const_witness true %w3 = shape.const_witness false @@ -98,6 +98,12 @@ return } +func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>, + %rhs : tensor<?xindex>) { + %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex> + return +} + func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size { %product = shape.mul %lhs, %rhs return %product: !shape.size diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir --- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir +++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir @@ -11,7 +11,7 @@ // REPLACE: shape.assuming %[[WITNESS]] // CANON-NEXT: test.source // CANON-NEXT: return - %0 = shape.cstr_broadcastable %arg0, %arg1 + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape %1 = shape.assuming %0 -> index { %2 = "test.source"() : () -> (index) shape.assuming_yield %2 : index @@ -45,7 +45,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index { // CANON-NEXT: test.source // CANON-NEXT: return - %0 = shape.cstr_broadcastable %arg0, %arg1 + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape %1 = shape.cstr_eq %arg0, %arg1 %2 = shape.assuming_all %0, %1 %3 = shape.assuming %0 -> index {