diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -30,7 +30,8 @@
   Shape,
   Size,
   ValueShape,
-  LAST_SHAPE_TYPE = ValueShape
+  Witness,
+  LAST_SHAPE_TYPE = Witness
 };
 } // namespace ShapeTypes
 
@@ -105,6 +106,22 @@
   }
 };
 
+/// The Witness represents a runtime constraint, to be used as shape related
+/// preconditions on code execution.
+class WitnessType : public Type::TypeBase<WitnessType, Type> {
+public:
+  using Base::Base;
+
+  static WitnessType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::Witness);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == ShapeTypes::Kind::Witness;
+  }
+};
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -17,6 +17,16 @@
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffects.td"
 
+def Shape_WitnessType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
+    BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
+  let typeDescription = [{
+    A witness is a specialized boolean value to represent preconditions to code
+    being executed with the alternative state being a platform specific form of
+    assertion.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Shape op definitions
 //===----------------------------------------------------------------------===//
@@ -311,4 +321,101 @@
   let hasFolder = 1;
 }
 
+// Shape constraint related ops.
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+  let summary = "Determines if 2 shapes can be successfully broadcasted.";
+  let description = [{
+    Given 2 input shapes, return a witness specifying if they are broadcastable.
+    This broadcastable follows the same logic as what shape.broadcast documents.
+
+    Example:
+    ```mlir
+    %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
+    %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
+    ```
+  }];
+
+  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
+  let summary = "Determines if all input shapes are equal.";
+  let description = [{
+    Given 1 or more input shapes, determine if all shapes are the exact same.
+    This returns a witness for runtime checking of this property.
+
+    Example:
+    ```mlir
+    %w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
+    %w1 = shape.cstr_eq([2,2], [1,2]) // Failure
+    ```
+  }];
+  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_AnyOp : Shape_Op<"any",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Return any combination of the input shapes.";
+  let description = [{
+    This operation takes multiple inputs and returns some combination of their
+    inputs. The returned set of dimensions is not guaranteed but is likely to
+    reserve a constant shape as soon as a combination of the dimensions allows
+    that.
+
+    Example:
+    ```mlir
+    %s0 = shape.any([2,2], [1,3]) // [1,2]
+    %s1 = shape.any([?,?], [1,2]) // [1,2]
+    ```
+  }];
+
+  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let results = (outs Shape_ShapeType:$result);
+}
+
+def Shape_CstrAllOp : Shape_Op<"cstr_all", []> {
+  let summary = "Return a logical AND of all witnesses.";
+  let description = [{
+    Used to simplify constraints as any single failing precondition is enough
+    to prevent execution.
+
+    Example:
+    ```mlir
+    %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
+    %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
+    %w2 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
+    %wf = shape.cstr_all(%w0, %w1) // Failure
+    %wt = shape.cstr_all(%w0, %w2) // Success
+    ```
+  }];
+
+  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
+  let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_AssertAndExecYieldOp : Shape_Op<"assert_and_exec_yield", [Terminator]> {
+  let summary = "Yield operation";
+  let description = [{
+    This yield operation represents a return operation within the assert_and_exec
+    region. The operation takes variable number of operands and produces no
+    results. The operand number and types must match the return signature of
+    the region that contains the operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def Shape_AssertAndExec : Shape_Op<"assert_and_exec",
+    [SingleBlockImplicitTerminator<"AssertAndExecYieldOp">, RecursiveSideEffects]> {
+  let summary = "assert on witness before executing the region.";
+  let description = [{
+    Executes the region if and only if all witness inputs are true.
+  }];
+  let arguments = (ins Variadic<Shape_WitnessType>);
+  let regions = (region SizedRegion<1>:$thenRegion);
+  let results = (outs Variadic<AnyType>:$results);
+}
+
 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -24,7 +24,8 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
       >();
-  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
+  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
+           WitnessType>();
   // Allow unknown operations during prototyping and testing. As the dialect is
   // still evolving it makes it simple to start with an unregistered ops and
   // try different variants before actually defining the op.
@@ -60,6 +61,8 @@
     return SizeType::get(getContext());
   if (keyword == "value_shape")
     return ValueShapeType::get(getContext());
+  if (keyword == "witness")
+    return WitnessType::get(getContext());
 
   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
   return Type();
@@ -83,6 +86,9 @@
   case ShapeTypes::ValueShape:
     os << "value_shape";
     return;
+  case ShapeTypes::Witness:
+    os << "witness";
+    return;
   default:
     llvm_unreachable("unexpected 'shape' type kind");
   }
@@ -282,6 +288,19 @@
   return DenseIntElementsAttr::get(type, shape);
 }
 
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+/// Infers a single `!shape.shape` result type irrespective of the operands.
+LogicalResult
+AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                        ValueRange operands, DictionaryAttr attributes,
+                        RegionRange regions,
+                        SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 namespace mlir {
 namespace shape {
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -62,3 +62,16 @@
   %1 = shape.const_shape [1, 2, 3]
   return
 }
+
+func @test_constraints() {
+  %0 = shape.const_shape []
+  %1 = shape.const_shape [1, 2, 3]
+  %w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
+  %w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
+  %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %w3 = "shape.cstr_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
+  "shape.assert_and_exec"(%w3) ( {
+    "shape.assert_and_exec_yield"(%2) : (!shape.shape) -> ()
+  }) : (!shape.witness) -> !shape.shape
+  return
+}