diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -368,7 +368,7 @@
 // Shape constraint related ops.
 //===----------------------------------------------------------------------===//
 
-//TODO(tpopp): Move the code below and witnesses to a different file.
+//TODO: Move the code below and witnesses to a different file.
 def Shape_AnyOp : Shape_Op<"any", [NoSideEffect,
                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Return any combination of the input shapes.";
@@ -499,7 +499,31 @@
   let assemblyFormat = "$inputs attr-dict";
 }
 
+// TODO: Support witness attributes and then make this ConstantLike.
+// Note: This operation might be replaced with a general op that takes a
+// True/False Attribute.
+def Shape_TrueWitnessOp : Shape_Op<"true_witness", [NoSideEffect]> {
+  let summary = "An operation that returns a successful witness.";
+  let description = [{
+    Produces a witness that always represents a satisfied constraint.
+    Constraint ops whose result is statically known to succeed, such as
+    `%w0` and `%w2` below, can be canonicalized to this op.
+
+    ```mlir
+    %0 = shape.const_shape [1, 2, 3]
+    %1 = shape.const_shape [1, 2, 3]
+    %w0 = shape.cstr_eq %0, %1 // Can be canonicalized to true_witness
+    %w1 = shape.true_witness
+    %w2 = shape.assuming_all %w0, %w1 // Can be canonicalized to true_witness
+    ```
+  }];
+  let builders = [OpBuilder<
+    "OpBuilder &b, OperationState &result",
+    "build(b, result, ::mlir::shape::WitnessType::get(b.getContext()));"
+  >];
 
-// Canonicalization patterns.
+  let assemblyFormat = "attr-dict";
+  let results = (outs Shape_WitnessType:$result);
+}
 
 #endif // SHAPE_OPS
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -73,7 +73,8 @@
   %1 = shape.const_shape [1, 2, 3]
   %w0 = shape.cstr_broadcastable %0, %1
   %w1 = shape.cstr_eq %0, %1
-  %w3 = shape.assuming_all %w0, %w1
+  %w2 = shape.true_witness
+  %w3 = shape.assuming_all %w0, %w1, %w2
   shape.assuming %w3 -> !shape.shape {
     %2 = shape.any %0, %1
     shape.assuming_yield %2 : !shape.shape