diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -406,12 +407,16 @@
     nothing else. They should not exist after a program is fully lowered and
     ready to execute.
   }];
-  let arguments = (ins Shape_WitnessType);
-  let regions = (region SizedRegion<1>:$thenRegion);
+  let arguments = (ins Shape_WitnessType:$condition);
+  let regions = (region SizedRegion<1>:$doRegion);
   let results = (outs Variadic<AnyType>:$results);
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
+    [NoSideEffect, ReturnLike, Terminator]> {
   let summary = "Yield operation";
   let description = [{
     This yield operation represents a return operation within the assert_and_exec
@@ -421,6 +426,10 @@
   }];
 
   let arguments = (ins Variadic<AnyType>:$operands);
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result",
+    [{ /* nothing to do */ }]>
+  ];
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -8,6 +8,7 @@
   MLIRShapeOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRControlFlowInterfaces
   MLIRDialect
   MLIRInferTypeOpInterface
   MLIRIR
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -107,6 +107,58 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AssumingOp
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of `shape.assuming`:
+//
+//   shape.assuming %witness (`->` type-list)? { region } attr-dict?
+//
+// The operand is resolved as `!shape.witness`. If the region's terminator is
+// elided in the textual form, it is re-created via `ensureTerminator`.
+static ParseResult parseAssumingOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  result.regions.reserve(1);
+  Region *doRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::OperandType cond;
+  if (parser.parseOperand(cond) ||
+      parser.resolveOperand(cond, builder.getType<WitnessType>(),
+                            result.operands))
+    return failure();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Parse the region and add a terminator if elided.
+  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  AssumingOp::ensureTerminator(*doRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+// Prints `shape.assuming` in the form accepted by parseAssumingOp. The region
+// terminator is printed only when the op has results; otherwise it is elided
+// and parseAssumingOp re-creates it when the text is parsed back.
+static void print(OpAsmPrinter &p, AssumingOp op) {
+  bool yieldsResults = !op.results().empty();
+
+  p << AssumingOp::getOperationName() << " " << op.condition();
+  if (yieldsResults)
+    p << " -> (" << op.getResultTypes() << ")";
+  p.printRegion(op.doRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/yieldsResults);
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -74,9 +74,9 @@
   %w0 = shape.cstr_broadcastable(%0, %1)
   %w1 = shape.cstr_eq(%0, %1)
   %w3 = shape.assuming_all(%w0, %w1)
-  "shape.assuming"(%w3) ( {
+  shape.assuming %w3 -> !shape.shape {
     %2 = shape.any(%0, %1)
     shape.assuming_yield %2 : !shape.shape
-  }) : (!shape.witness) -> !shape.shape
+  }
   return
 }