diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -426,12 +427,16 @@
     nothing else. They should not exist after a program is fully lowered and
     ready to execute.
   }];
-  let arguments = (ins Shape_WitnessType);
-  let regions = (region SizedRegion<1>:$thenRegion);
+  let arguments = (ins Shape_WitnessType:$witness);
+  let regions = (region SizedRegion<1>:$doRegion);
   let results = (outs Variadic<AnyType>:$results);
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
+                                     [NoSideEffect, ReturnLike, Terminator]> {
   let summary = "Yield operation";
   let description = [{
     This yield operation represents a return operation within the assert_and_exec
@@ -441,6 +446,11 @@
   }];
 
   let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result",
+              [{ /* nothing to do */ }]>
+  ];
 }
 
 def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -8,6 +8,7 @@
   MLIRShapeOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRControlFlowInterfaces
   MLIRDialect
   MLIRInferTypeOpInterface
   MLIRIR
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -108,6 +108,58 @@
 }
 
 //===----------------------------------------------------------------------===//
+// AssumingOp
+//===----------------------------------------------------------------------===//
+
+// Parses the custom assembly form of `shape.assuming`:
+//
+//   shape.assuming %witness [-> result-types] region attr-dict
+//
+// If the region's terminator is elided, `ensureTerminator` inserts it.
+static ParseResult parseAssumingOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  result.regions.reserve(1);
+  Region *doRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::OperandType cond;
+  if (parser.parseOperand(cond) ||
+      parser.resolveOperand(cond, builder.getType<WitnessType>(),
+                            result.operands))
+    return failure();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Parse the region and add a terminator if elided.
+  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+// Prints the form accepted by `parseAssumingOp`. The region terminator is
+// printed only when the op has results, so a result-less op round-trips with
+// its implicit terminator elided.
+static void print(OpAsmPrinter &p, AssumingOp op) {
+  bool yieldsResults = !op.results().empty();
+
+  p << AssumingOp::getOperationName() << " " << op.witness();
+  if (yieldsResults) {
+    p << " -> (" << op.getResultTypes() << ")";
+  }
+  p.printRegion(op.doRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/yieldsResults);
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -74,9 +74,9 @@
   %w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
   %w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
   %w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
-  "shape.assuming"(%w3) ( {
+  shape.assuming %w3 -> !shape.shape {
     %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
     "shape.assuming_yield"(%2) : (!shape.shape) -> ()
-  }) : (!shape.witness) -> !shape.shape
+  }
   return
 }