diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -30,7 +30,8 @@ Shape, Size, ValueShape, - LAST_SHAPE_TYPE = ValueShape + Witness, + LAST_SHAPE_TYPE = Witness }; } // namespace ShapeTypes @@ -105,6 +106,22 @@ } }; +/// The Witness represents a runtime constraint, to be used as shape related +/// preconditions on code execution. +class WitnessType : public Type::TypeBase<WitnessType, Type> { +public: + using Base::Base; + + static WitnessType get(MLIRContext *context) { + return Base::get(context, ShapeTypes::Kind::Witness); + } + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == ShapeTypes::Kind::Witness; + } +}; + #define GET_OP_CLASSES #include "mlir/Dialect/Shape/IR/ShapeOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -17,6 +17,32 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def Shape_WitnessType : DialectType<ShapeDialect, + CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">, + BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> { + let typeDescription = [{ + A witness is a structural device in the compiler to maintain ordering of + code relying on information obtained from passing assertions. Witnesses do + not represent any physical data. + + "cstr_" operations will return witnesses and be lowered into assertion logic + when not resolvable at compile time. + + "assuming_" operations will take witnesses as input and represent only + information to the compiler, so they do not exist in executing code. Code + that is dependent on "assuming_" operations can assume all cstr operations + transitively before are honored as true.
+ + These abstractions are intended to allow the compiler more freedom with + assertions by merely showing the assertion through dataflow at this time + rather than a side effecting operation that acts as a barrier. This can be + viewed similarly to a compiler representation of promises from asynchronous, + possibly crashing assertions. Reliant code will not be reordered to before + the code and non-reliant code can be reordered freely, and there are no + guarantees on the final ordering of the assertions or their related code. + }]; +} + //===----------------------------------------------------------------------===// // Shape op definitions //===----------------------------------------------------------------------===// @@ -313,4 +339,123 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Shape constraint related ops. +//===----------------------------------------------------------------------===// + +// TODO(tpopp): Move the code below and witnesses to a different file. +def Shape_AnyOp : Shape_Op<"any", + [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> { + let summary = "Return any combination of the input shapes."; + let description = [{ + This operation takes multiple input shapes and returns some combination of + their dimensions. This can be best seen with examples below. + + The result is undefined, but still side-effect free, in cases where the + inputs have differing ranks or differ in extents of shared dimensions. + + Example: + ```mlir + %s0 = shape.any([2,?], [?,3]) // [2,3] + %s1 = shape.any([?,?], [1,2]) // [1,2] + ``` + }]; + + let arguments = (ins Variadic<Shape_ShapeType>:$inputs); + let results = (outs Shape_ShapeType:$result); +} + +def Shape_AssumingAllOp : Shape_Op<"assuming_all", []> { + let summary = "Return a logical AND of all witnesses."; + let description = [{ + Used to simplify constraints as any single failing precondition is enough + to prevent execution.
+ + "assuming" operations represent an execution order restriction to the + compiler, information for dependent code to rely on (by assuming), and + nothing else. They should not exist after a program is fully lowered and + ready to execute. + + Example: + ```mlir + %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success + %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure + %w2 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success + %wf = shape.assuming_all(%w0, %w1) // Failure + %wt = shape.assuming_all(%w0, %w2) // Success + ``` + }]; + + let arguments = (ins Variadic<Shape_WitnessType>:$inputs); + let results = (outs Shape_WitnessType:$result); +} + +def Shape_Assuming : Shape_Op<"assuming", + [SingleBlockImplicitTerminator<"AssumingYield">, + RecursiveSideEffects]> { + let summary = "Execute the region."; + let description = [{ + Executes the region assuming all witnesses are true. + + "assuming" operations represent an execution order restriction to the + compiler, information for dependent code to rely on (by assuming), and + nothing else. They should not exist after a program is fully lowered and + ready to execute. + }]; + let arguments = (ins Shape_WitnessType:$witness); + let regions = (region SizedRegion<1>:$thenRegion); + let results = (outs Variadic<AnyType>:$results); +} + +def Shape_AssumingYield : Shape_Op<"assuming_yield", [Terminator]> { + let summary = "Yield operation"; + let description = [{ + This yield operation represents a return operation within the shape.assuming + region. The operation takes variable number of operands and produces no + results. The operand number and types must match the return signature of + the region that contains the operation. + }]; + + let arguments = (ins Variadic<AnyType>:$operands); +} + +def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> { + let summary = "Determines if 2 shapes can be successfully broadcasted."; + let description = [{ + Given 2 input shapes, return a witness specifying if they are broadcastable.
+ This broadcastable follows the same logic as what shape.broadcast documents. + + "cstr" operations represent runtime assertions. + + Example: + ```mlir + %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success + %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure + ``` + }]; + + let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs); + let results = (outs Shape_WitnessType:$result); +} + +def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { + let summary = "Determines if all input shapes are equal."; + let description = [{ + Given 1 or more input shapes, determine if all shapes are the exact same. + + "cstr" operations represent runtime assertions. + + Example: + ```mlir + %w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success + %w1 = shape.cstr_eq([2,2], [1,2]) // Failure + ``` + }]; + let arguments = (ins Variadic<Shape_ShapeType>:$inputs); + let results = (outs Shape_WitnessType:$result); +} + + +// Canonicalization patterns. + #endif // SHAPE_OPS diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -24,7 +24,8 @@ #define GET_OP_LIST #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" >(); - addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>(); + addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, + WitnessType>(); // Allow unknown operations during prototyping and testing. As the dialect is // still evolving it makes it simple to start with an unregistered ops and // try different variants before actually defining the op.
@@ -60,6 +61,8 @@ return SizeType::get(getContext()); if (keyword == "value_shape") return ValueShapeType::get(getContext()); + if (keyword == "witness") + return WitnessType::get(getContext()); parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; return Type(); @@ -83,11 +86,27 @@ case ShapeTypes::ValueShape: os << "value_shape"; return; + case ShapeTypes::Witness: + os << "witness"; + return; default: llvm_unreachable("unexpected 'shape' type kind"); } } +//===----------------------------------------------------------------------===// +// AnyOp +//===----------------------------------------------------------------------===// + +LogicalResult +AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl<Type> &inferredReturnTypes) { + inferredReturnTypes.push_back(ShapeType::get(context)); + return success(); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -67,3 +67,16 @@ %0 = shape.shape_of %arg0 : tensor return %0 : !shape.shape } + +func @test_constraints() { + %0 = shape.const_shape [] + %1 = shape.const_shape [1, 2, 3] + %w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness + %w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness + %w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness + "shape.assuming"(%w3) ( { + %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + "shape.assuming_yield"(%2) : (!shape.shape) -> () + }) : (!shape.witness) -> !shape.shape + return +}