diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -923,6 +923,70 @@
   let summary = "base-e exponential of the specified value";
 }
 
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+def ExecuteRegionOp : Std_Op<"execute_region"> {
+  let summary = "operation that executes its region exactly once";
+  let description = [{
+    The execute_region operation executes the region held exactly once. The op
+    cannot have any operands, nor does its region have any arguments. All SSA
+    values that dominate the op can be accessed inside. The op's region can have
+    multiple blocks and the blocks can have terminators the same way as FuncOp.
+    The values returned from this op's region define the op's results. The op
+    primarily provides control flow encapsulation and isolation from a parent
+    op's control flow restrictions if any; for example, it allows representation
+    of inlined calls inside structured control flow ops with restrictions like
+    affine.for/if and loop.for/if, and thus enables optimization of IR in such
+    a mixed form.
+
+    Example:
+
+    ```mlir
+    loop.for %i = 0 to 128 {
+      %y = execute_region -> i32 {
+        %x = load %A[%i] : memref<128xi32>
+        return %x : i32
+      }
+    }
+
+    affine.for %i = 0 to 100 {
+      "foo"() : () -> ()
+      %v = execute_region -> i64 {
+        cond_br %cond, ^bb1, ^bb2
+
+      ^bb1:
+        %c1 = constant 1 : i64
+        br ^bb3(%c1 : i64)
+
+      ^bb2:
+        %c2 = constant 2 : i64
+        br ^bb3(%c2 : i64)
+
+      ^bb3(%x : i64):
+        return %x : i64
+      }
+      "bar"(%v) : (i64) -> ()
+    }
+    ```
+  }];
+
+  let results = (outs Variadic<AnyType>);
+
+  let regions = (region AnyRegion:$region);
+
+  // TODO: If the parent is a func like op (which would be the case if all other
+  // ops are from the std dialect), the inliner logic could be readily used to
+  // inline.
+  let hasCanonicalizer = 0;
+
+  // TODO: can fold if it returns a constant.
+  // TODO: Single block execute_region ops can be readily inlined irrespective
+  // of which op is a parent. Add a fold for this.
+  let hasFolder = 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1381,6 +1381,65 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form:
+///
+///   (ssa-id `=`)? `execute_region` (`->` function-result-type)? `{`
+///      block+
+///   `}` attr-dict?
+///
+/// Example:
+///   std.execute_region -> i32 {
+///     %idx = load %rI[%i] : memref<128xi32>
+///     return %idx : i32
+///   }
+static ParseResult parseExecuteRegionOp(OpAsmParser &parser,
+                                        OperationState &result) {
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Introduce the body region and parse it. The region has no arguments.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // Parse the optional attribute list exactly once; it follows the region,
+  // matching the order in which the printer emits it.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the custom form accepted by parseExecuteRegionOp.
+static void print(OpAsmPrinter &p, ExecuteRegionOp op) {
+  p << ExecuteRegionOp::getOperationName();
+  // Omits the arrow when there are no results and parenthesizes multiple
+  // result types, which parseOptionalArrowTypeList requires to round-trip.
+  p.printOptionalArrowTypeList(op.getResultTypes());
+
+  p.printRegion(op.region(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+/// Verifies that the region has at least one block and that its entry block
+/// takes no arguments.
+static LogicalResult verify(ExecuteRegionOp op) {
+  if (op.region().empty())
+    return op.emitOpError("region needs to have at least one block");
+
+  if (op.region().front().getNumArguments() > 0)
+    return op.emitOpError("region cannot have any arguments");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -591,6 +591,27 @@
   // CHECK: %4 = call_indirect %f_0(%arg0) : (i32) -> i32
   %3 = "std.call_indirect"(%f_0, %arg0) : ((i32) -> i32, i32) -> i32
 
+  // CHECK: execute_region -> i64 {
+  // CHECK-NEXT: constant
+  // CHECK-NEXT: return
+  // CHECK-NEXT: }
+  %4 = execute_region -> i64 {
+    %c1 = constant 1 : i64
+    return %c1 : i64
+  }
+
+  // CHECK: execute_region {
+  // CHECK-NEXT: br ^bb1
+  // CHECK-NEXT: ^bb1:   // pred: ^bb0
+  // CHECK-NEXT: return
+  // CHECK-NEXT: }
+  "std.execute_region"() ({
+  ^bb0:
+    br ^bb1
+  ^bb1:
+    return
+  }) : () -> ()
+
   return
 }
 
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1231,3 +1231,14 @@
   // expected-error @+1 {{expected unsigned integer elements, but parsed negative value}}
   "foo"() {bar = dense<[5, -5]> : vector<2xui32>} : () -> ()
 }
+
+// -----
+
+func @execute_region() {
+  // expected-error @+1 {{region cannot have any arguments}}
+  "std.execute_region"() ({
+  ^bb0(%i : i32):
+    return
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -491,6 +491,33 @@
   return
 }
 
+// CHECK-LABEL: func @propagate_into_execute_region
+func @propagate_into_execute_region() {
+  %cond = constant 0 : i1
+  affine.for %i = 0 to 100 {
+    "foo"() : () -> ()
+    %v = execute_region -> i64 {
+      cond_br %cond, ^bb1, ^bb2
+
+    ^bb1:
+      %c1 = constant 1 : i64
+      br ^bb3(%c1 : i64)
+
+    ^bb2:
+      %c2 = constant 2 : i64
+      br ^bb3(%c2 : i64)
+
+    ^bb3(%x : i64):
+      return %x : i64
+    }
+    "bar"(%v) : (i64) -> ()
+    // CHECK: std.execute_region -> i64 {
+    // CHECK-NEXT: return %c2_i64 : i64
+    // CHECK-NEXT: }
+  }
+  return
+}
+
 // CHECK-LABEL: func @const_fold_propagate
 func @const_fold_propagate() -> memref<?x?xf32> {
   %VT_i = constant 512 : index