diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -39,6 +39,7 @@ void mlir::populateConvertShapeConstraintsConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); + patterns.insert(ctx); patterns.insert(ctx); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td @@ -24,4 +24,11 @@ (Shape_IsBroadcastableOp $shapes), (BroadcastableStringAttr))>; +def EqStringAttr : NativeCodeCall<[{ + $_builder.getStringAttr("required equal shapes") +}]>; + +def CstrEqToRequire : Pat<(Shape_CstrEqOp $shapes), + (Shape_CstrRequireOp (Shape_ShapeEqOp $shapes), (EqStringAttr))>; + #endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir --- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir +++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir @@ -14,6 +14,19 @@ return %witness : !shape.witness } +// CHECK-LABEL: func @cstr_eq( +// CHECK-SAME: %[[LHS:.*]]: tensor, +// CHECK-SAME: %[[RHS:.*]]: tensor) -> !shape.witness { +// CHECK: %[[RET:.*]] = shape.const_witness true +// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]] +// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes" +// CHECK: return %[[RET]] : !shape.witness +// CHECK: } +func @cstr_eq(%arg0: tensor, %arg1: tensor) -> !shape.witness { + %witness = shape.cstr_eq %arg0, %arg1 : tensor, tensor + return %witness : !shape.witness +} + // CHECK-LABEL: func @cstr_require func @cstr_require(%arg0: i1) -> !shape.witness { // CHECK: %[[RET:.*]] = shape.const_witness true