[mlir][shape] Give shape.cstr_eq explicit operand types

Summary of the changes in this patch:
  * ShapeOps.td: Shape_CstrEqOp gains InferTypeOpInterface, its variadic
    operand is renamed from $inputs to $shapes, and the assembly format now
    prints the operand types (`:` type($shapes)). The result type is always
    !shape.witness, supplied by an inline inferReturnTypes.
  * Tests: every existing shape.cstr_eq use is updated to the new syntax, and
    ops.mlir gains @eq_on_extent_tensors.

NOTE(review): this patch text was reconstructed from a whitespace-collapsed
copy in which angle-bracket arguments had been stripped. The restored
`Variadic<...>` constraints, the `tensor<?xindex>` types, and the
indentation / blank lines of context lines are best-effort -- confirm them
against the target tree before applying.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -783,7 +783,7 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
   let summary = "Determines if all input shapes are equal";
   let description = [{
     Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -796,10 +796,21 @@
     %w1 = shape.cstr_eq [2,2], [1,2] // Failure
     ```
   }];
-  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$inputs attr-dict";
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+        ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+      return ::mlir::success();
+    }
+  }];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -360,7 +360,7 @@
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg0, %arg0
+  %0 = shape.cstr_eq %arg0, %arg0, %arg0 : !shape.shape, !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -375,7 +375,7 @@
   %cs0 = shape.const_shape [0, 1] : !shape.shape
   %cs1 = shape.const_shape [0, 1] : !shape.shape
   %cs2 = shape.const_shape [0, 1] : !shape.shape
-  %0 = shape.cstr_eq %cs0, %cs1, %cs2
+  %0 = shape.cstr_eq %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -391,7 +391,7 @@
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [0, 1] : !shape.shape
   %cs1 = shape.const_shape [3, 1] : !shape.shape
-  %0 = shape.cstr_eq %cs0, %cs1
+  %0 = shape.cstr_eq %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -403,7 +403,7 @@
   // CHECK-NEXT: shape.cstr_eq
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg1
+  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -102,7 +102,7 @@
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
   %true = constant true
   %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
-  %w1 = shape.cstr_eq %0, %1
+  %w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
   %w2 = shape.const_witness true
   %w3 = shape.const_witness false
   %w4 = shape.cstr_require %true, "msg"
@@ -114,6 +114,12 @@
   return
 }
 
+func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
+                           %rhs : tensor<?xindex>) {
+  %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+  return
+}
+
 func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                       %rhs : tensor<?xindex>) {
   %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
--- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -29,7 +29,7 @@
   // REPLACE: shape.assuming %[[WITNESS]]
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg1
+  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.assuming %0 -> index {
     %2 = "test.source"() : () -> (index)
     shape.assuming_yield %2 : index
@@ -46,7 +46,7 @@
   // CANON-NEXT: test.source
   // CANON-NEXT: return
   %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
-  %1 = shape.cstr_eq %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   %2 = shape.assuming_all %0, %1
   %3 = shape.assuming %0 -> index {
     %4 = "test.source"() : () -> (index)