diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -738,5 +738,27 @@ let hasFolder = 1; } +def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> { + let summary = "Represents a runtime assertion that an i1 is `true`"; + let description = [{ + Represents a runtime assertion that an i1 is true. It returns a + !shape.witness to order this assertion. + + For simplicity, prefer using other cstr_* ops if they are available for a + given constraint. + + Example: + ```mlir + %bool = ... + %w0 = shape.cstr_require %bool // Passing if `%bool` is true. + ``` + }]; + let arguments = (ins I1:$pred); + let results = (outs Shape_WitnessType:$result); + + let assemblyFormat = "$pred attr-dict"; + + let hasFolder = 1; +} #endif // SHAPE_OPS diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -490,6 +490,16 @@ OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } +//===----------------------------------------------------------------------===// +// CstrRequireOp +//===----------------------------------------------------------------------===// + +// Forward the constant value of `pred` (if any) as the folded result. When +// `pred` is not a constant, operands[0] is null and no folding happens. +OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { + return operands[0]; +} + //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -386,7 +386,32 @@ } // ----- +// cstr_require with constant can be folded +// CHECK-LABEL: func @cstr_require_fold +func @cstr_require_fold() { + // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %true = constant true + %0 = shape.cstr_require %true + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_require without constant cannot be folded +// CHECK-LABEL: func @cstr_require_no_fold +func @cstr_require_no_fold(%arg0: i1) { + // CHECK-NEXT: shape.cstr_require + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_require %arg0 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- // assuming_all with known passing witnesses can be folded // CHECK-LABEL: func @f func @f() { diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -100,12 +100,14 @@ func @test_constraints() { %0 = shape.const_shape [] : !shape.shape %1 = shape.const_shape [1, 2, 3] : !shape.shape + %true = constant true %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape %w1 = shape.cstr_eq %0, %1 %w2 = shape.const_witness true %w3 = shape.const_witness false - %w4 = shape.assuming_all %w0, %w1, %w2, %w3 - shape.assuming %w4 -> !shape.shape { + %w4 = shape.cstr_require %true + %w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4 + shape.assuming %w_all -> !shape.shape { %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape shape.assuming_yield %2 : !shape.shape }