diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -713,10 +713,10 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_AssumingOp : Shape_Op<"assuming",
-                           [SingleBlockImplicitTerminator<"AssumingYieldOp">,
-                            DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-                            RecursiveSideEffects]> {
+def Shape_AssumingOp : Shape_Op<"assuming", [
+    SingleBlockImplicitTerminator<"AssumingYieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+    RecursiveSideEffects]> {
   let summary = "Execute the region";
   let description = [{
     Executes the region assuming all witnesses are true.
@@ -742,6 +742,11 @@
     static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$witness,
+        CArg<"function_ref<SmallVector<Value, 2>(OpBuilder &, Location)>">)>
+  ];
+
   let hasCanonicalizer = 1;
 }
 
@@ -757,7 +762,9 @@
 
   let arguments = (ins Variadic<AnyType>:$operands);
 
-  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+  let builders = [
+    OpBuilder<(ins), [{ /* nothing to do */ }]>,
+  ];
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -311,6 +311,31 @@
   rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
 }
 
+/// Builds an `AssumingOp` guarded by `witness`. `bodyBuilder` is called with
+/// the insertion point at the start of the op's single body block; the values
+/// it returns are yielded through an `AssumingYieldOp` terminator and their
+/// types become the result types of the op.
+void AssumingOp::build(
+    OpBuilder &builder, OperationState &result, Value witness,
+    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
+  result.addOperands(witness);
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+
+  // Build body. The guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&bodyBlock);
+  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
+  builder.create<AssumingYieldOp>(result.location, yieldValues);
+
+  // The op's results mirror the yielded values one-to-one.
+  SmallVector<Type, 2> assumingTypes;
+  for (Value v : yieldValues)
+    assumingTypes.push_back(v.getType());
+  result.addTypes(assumingTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//