diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -448,7 +448,7 @@
 // Shape constraint related ops.
 //===----------------------------------------------------------------------===//
 
-//TODO(tpopp): Move the code below and witnesses to a different file.
+//TODO: Move the code below and witnesses to a different file.
 def Shape_AnyOp : Shape_Op<"any", [NoSideEffect,
                                    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Return any combination of the input shapes.";
@@ -485,11 +485,11 @@
     Example:
 
     ```mlir
-    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
     %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
-    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
     %wf = shape.assuming_all %w0, %w1 // Failure
-    %wt = shape.assuming_all %w0, %w2 // Success
+    %wt = shape.assuming_all %w0, %w2 // Passing
     ```
   }];
 
@@ -549,7 +549,7 @@
     Example:
 
     ```mlir
-    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
     %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
     ```
   }];
@@ -569,7 +569,7 @@
     Example:
 
     ```mlir
-    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
     %w1 = shape.cstr_eq [2,2], [1,2] // Failure
     ```
   }];
@@ -579,6 +579,28 @@
   let assemblyFormat = "$inputs attr-dict";
 }
 
-// Canonicalization patterns.
+def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
+  let summary = "An operation that returns a statically known witness value";
+  let description = [{
+    This operation represents a statically known witness result. It can often
+    be used to canonicalize/fold constraint and assuming code that will always
+    pass.
+
+    ```mlir
+    %0 = shape.const_shape [1,2,3]
+    %1 = shape.const_shape [1, 2, 3]
+    %w0 = shape.cstr_eq %0, %1 // Can be folded to "const_witness true"
+    %w1 = shape.const_witness true
+    %w2 = shape.assuming_all %w0, %w1 // Can be folded to "const_witness true"
+    ```
+  }];
+  let arguments = (ins BoolAttr:$passing);
+  let results = (outs Shape_WitnessType:$result);
+
+  let assemblyFormat = "$passing attr-dict";
+
+  let hasFolder = 1;
+}
+
 
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -42,6 +42,9 @@
   if (auto sizeType = type.dyn_cast<SizeType>()) {
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   }
+  if (auto witnessType = type.dyn_cast<WitnessType>()) {
+    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+  }
   return nullptr;
 }
 
@@ -266,6 +269,12 @@
 
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
+//===----------------------------------------------------------------------===//
+// ConstWitnessOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -73,8 +73,10 @@
   %1 = shape.const_shape [1, 2, 3]
   %w0 = shape.cstr_broadcastable %0, %1
   %w1 = shape.cstr_eq %0, %1
-  %w3 = shape.assuming_all %w0, %w1
-  shape.assuming %w3 -> !shape.shape {
+  %w2 = shape.const_witness true
+  %w3 = shape.const_witness false
+  %w4 = shape.assuming_all %w0, %w1, %w2, %w3
+  shape.assuming %w4 -> !shape.shape {
     %2 = shape.any %0, %1
     shape.assuming_yield %2 : !shape.shape
   }