diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -55,6 +55,70 @@
   let verifier = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+def ExecuteRegionOp : SCF_Op<"execute_region"> {
+  let summary = "operation that executes its region exactly once";
+  let description = [{
+    The `execute_region` operation executes the region it holds exactly once.
+    The op cannot have any operands, nor does its region have any arguments.
+    All SSA values that dominate the op can be accessed inside. The op's region
+    can have multiple blocks, and the blocks can have terminators the same way
+    as FuncOp. The values yielded from the op's region define the op's results.
+    The op primarily provides control flow encapsulation and isolation from a
+    parent op's control flow restrictions, if any; for example, it allows the
+    representation of inlined calls inside structured control flow ops with
+    restrictions like affine.for/if and scf.for/if, and thus the optimization
+    of IR in such a mixed form.
+
+    Example:
+
+    ```mlir
+    scf.for %i = %c0 to %c128 step %c1 {
+      %y = scf.execute_region -> i32 {
+        %x = load %A[%i] : memref<128xi32>
+        scf.yield %x : i32
+      }
+    }
+
+    affine.for %i = 0 to 100 {
+      "foo"() : () -> ()
+      %v = scf.execute_region -> i64 {
+        cond_br %cond, ^bb1, ^bb2
+
+      ^bb1:
+        %c1 = constant 1 : i64
+        br ^bb3(%c1 : i64)
+
+      ^bb2:
+        %c2 = constant 2 : i64
+        br ^bb3(%c2 : i64)
+
+      ^bb3(%x : i64):
+        scf.yield %x : i64
+      }
+      "bar"(%v) : (i64) -> ()
+    }
+    ```
+  }];
+
+  let results = (outs Variadic<AnyType>);
+
+  let regions = (region AnyRegion:$region);
+
+  // TODO: If the parent is a func like op (which would be the case if all other
+  // ops are from the std dialect), the inliner logic could be readily used to
+  // inline.
+  let hasCanonicalizer = 0;
+
+  // TODO: can fold if it returns a constant.
+  // TODO: Single block execute_region ops can be readily inlined irrespective
+  // of which op is a parent. Add a fold for this.
+  let hasFolder = 0;
+}
+
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -598,8 +662,8 @@
 }
 
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
-                               ParentOneOf<["IfOp, ForOp", "ParallelOp",
-                                            "WhileOp"]>]> {
+                               ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp",
+                                            "ParallelOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -69,6 +69,59 @@
   builder.create<scf::YieldOp>(loc);
 }
 
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+///
+/// (ssa-id `=`)? `execute_region` (`->` function-result-type)? `{`
+///    block+
+/// `}` attr-dict
+///
+/// Example:
+///   scf.execute_region -> i32 {
+///     %idx = load %rI[%i] : memref<128xi32>
+///     scf.yield %idx : i32
+///   }
+///
+static ParseResult parseExecuteRegionOp(OpAsmParser &parser,
+                                        OperationState &result) {
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Introduce the body region and parse it. The region takes no arguments,
+  // so no argument names or types are supplied.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // Parse the optional attribute dictionary that follows the region.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, ExecuteRegionOp op) {
+  p << ExecuteRegionOp::getOperationName();
+  // Parenthesizes multiple result types so the output round-trips.
+  p.printOptionalArrowTypeList(op.getResultTypes());
+
+  p.printRegion(op.region(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+static LogicalResult verify(ExecuteRegionOp op) {
+  if (op.region().empty())
+    return op.emitOpError("region needs to have at least one block");
+  if (op.region().front().getNumArguments() > 0)
+    return op.emitOpError("region cannot have any arguments");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -205,9 +258,9 @@
         parser.parseArrowTypeList(result.types))
       return failure();
     // Resolve input operands.
-    for (auto operand_type : llvm::zip(operands, result.types))
-      if (parser.resolveOperand(std::get<0>(operand_type),
-                                std::get<1>(operand_type), result.operands))
+    for (auto operandType : llvm::zip(operands, result.types))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
         return failure();
   }
   // Induction variable.
@@ -240,7 +293,7 @@
 }
 
 LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
-  for (auto op : ops)
+  for (auto *op : ops)
     op->moveBefore(*this);
   return success();
 }
@@ -1618,7 +1671,7 @@
 }
 
 LogicalResult ParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
-  for (auto op : ops)
+  for (auto *op : ops)
     op->moveBefore(*this);
   return success();
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -897,3 +897,33 @@
 // CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
 // CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @propagate_into_execute_region
+func @propagate_into_execute_region() {
+  %cond = constant 0 : i1
+  affine.for %i = 0 to 100 {
+    "test.foo"() : () -> ()
+    %v = scf.execute_region -> i64 {
+      cond_br %cond, ^bb1, ^bb2
+
+    ^bb1:
+      %c1 = constant 1 : i64
+      br ^bb3(%c1 : i64)
+
+    ^bb2:
+      %c2 = constant 2 : i64
+      br ^bb3(%c2 : i64)
+
+    ^bb3(%x : i64):
+      scf.yield %x : i64
+    }
+    "test.bar"(%v) : (i64) -> ()
+    // CHECK: %[[C2:.*]] = constant 2 : i64
+    // CHECK: scf.execute_region -> i64 {
+    // CHECK-NEXT: scf.yield %[[C2]] : i64
+    // CHECK-NEXT: }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -428,7 +428,7 @@
 
 func @yield_invalid_parent_op() {
   "my.op"() ({
-    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}}
+    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
     scf.yield
   }) : () -> ()
   return
@@ -510,3 +510,23 @@
     "some.other_terminator"() : () -> ()
   }
 }
+
+// -----
+
+func @execute_region() {
+  // expected-error @+1 {{region cannot have any arguments}}
+  "scf.execute_region"() ({
+  ^bb0(%i : i32):
+    scf.yield
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func @execute_region_empty() {
+  // expected-error @+1 {{region needs to have at least one block}}
+  "scf.execute_region"() ({
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -279,3 +279,28 @@
   }
   return
 }
+
+// CHECK-LABEL: func @execute_region
+func @execute_region() -> i64 {
+  // CHECK: scf.execute_region -> i64 {
+  // CHECK-NEXT: constant
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: }
+  %res = scf.execute_region -> i64 {
+    %c1 = constant 1 : i64
+    scf.yield %c1 : i64
+  }
+
+  // CHECK: scf.execute_region {
+  // CHECK-NEXT: br ^bb1
+  // CHECK-NEXT: ^bb1:
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: }
+  "scf.execute_region"() ({
+  ^bb0:
+    br ^bb1
+  ^bb1:
+    scf.yield
+  }) : () -> ()
+  return %res : i64
+}
diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim
--- a/mlir/utils/vim/syntax/mlir.vim
+++ b/mlir/utils/vim/syntax/mlir.vim
@@ -53,8 +53,10 @@
 syn match mlirOps /\/
 syn match mlirOps /\/
 syn match mlirOps /\/
-syn match mlirOps /\/
-syn match mlirOps /\/
+syn match mlirOps /\/
+syn match mlirOps /\/
+syn match mlirOps /\/
+syn match mlirOps /\/
 " TODO: dialect name prefixed ops (llvm or std).