diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -36,6 +36,25 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def ConditionOp : SCF_Op<"condition",
+    [HasParent<"WhileOp">, NoSideEffect, Terminator]> {
+  let summary = "loop continuation condition";
+  let description = [{
+    This operation accepts the continuation (i.e., inverse of exit) condition
+    of the `scf.while` construct. If its first argument is true, the "after"
+    region of `scf.while` is executed, with the remaining arguments forwarded
+    to the entry block of the region. Otherwise, the loop terminates.
+  }];
+
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$args);
+
+  let assemblyFormat =
+      [{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
+
+  // Override the default verifier, everything is checked by traits.
+  let verifier = ?;
+}
+
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -413,8 +432,135 @@
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+def WhileOp : SCF_Op<"while",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+     RecursiveSideEffects]> {
+  let summary = "a generic 'while' loop";
+  let description = [{
+    This operation represents a generic "while"/"do-while" loop that keeps
+    iterating as long as a condition is satisfied. There is no restriction on
+    the complexity of the condition. It consists of two regions (with single
+    block each): "before" region and "after" region. The names of regions
+    indicate whether they execute before or after the condition check.
+    Therefore, if the main loop payload is located in the "before" region, the
+    operation is a "do-while" loop. Otherwise, it is a "while" loop.
+
+    The "before" region terminates with a special operation, `scf.condition`,
+    that accepts as its first operand an `i1` value indicating whether to
+    proceed to the "after" region (value is `true`) or not. The two regions
+    communicate by means of region arguments. Initially, the "before" region
+    accepts as arguments the operands of the `scf.while` operation and uses them
+    to evaluate the condition. It forwards the trailing, non-condition operands
+    of the `scf.condition` terminator either to the "after" region if the
+    control flow is transferred there or to results of the `scf.while` operation
+    otherwise. The "after" region takes as arguments the values produced by the
+    "before" region and uses `scf.yield` to supply new arguments for the
+    "before" region, into which it transfers the control flow unconditionally.
+
+    A simple "while" loop can be represented as follows.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
+      /* "Before" region.
+       * In a "while" loop, this region computes the condition. */
+      %condition = call @evaluate_condition(%arg1) : (f32) -> i1
+
+      /* Forward the argument (as result or "after" region argument). */
+      scf.condition(%condition) %arg1 : f32
+
+    } do {
+    ^bb0(%arg2: f32):
+      /* "After" region.
+       * In a "while" loop, this region is the loop body. */
+      %next = call @payload(%arg2) : (f32) -> f32
+
+      /* Forward the new value to the "before" region.
+       * The operand types must match the types of the `scf.while` operands. */
+      scf.yield %next : f32
+    }
+    ```
+
+    A simple "do-while" loop can be represented by reducing the "after" block
+    to a simple forwarder.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
+      /* "Before" region.
+       * In a "do-while" loop, this region contains the loop body. */
+      %next = call @payload(%arg1) : (f32) -> f32
+
+      /* And also evaluates the condition. */
+      %condition = call @evaluate_condition(%arg1) : (f32) -> i1
+
+      /* Loop through the "after" region. */
+      scf.condition(%condition) %next : f32
+
+    } do {
+    ^bb0(%arg2: f32):
+      /* "After" region.
+       * Forwards the values back to "before" region unmodified. */
+      scf.yield %arg2 : f32
+    }
+    ```
+
+    Note that the types of region arguments need not match each other.
+    The op expects the operand types to match with argument types of the
+    "before" region; the result types to match with the trailing operand types
+    of the terminator of the "before" region, and with the argument types of the
+    "after" region. The following scheme can be used to share the results of
+    some operations executed in the "before" region with the "after" region,
+    avoiding the need to recompute them.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> i64 {
+      /* One can perform some computations, e.g., necessary to evaluate the
+       * condition, in the "before" region and forward their results to the
+       * "after" region. */
+      %shared = call @shared_compute(%arg1) : (f32) -> i64
+
+      /* Evaluate the condition. */
+      %condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1
+
+      /* Forward the result of the shared computation to the "after" region.
+       * The types must match the arguments of the "after" region as well as
+       * those of the `scf.while` results. */
+      scf.condition(%condition) %shared : i64
+
+    } do {
+    ^bb0(%arg2: i64):
+      /* Use the partial result to compute the rest of the payload in the
+       * "after" region. */
+      %res = call @payload(%arg2) : (i64) -> f32
+
+      /* Forward the new value to the "before" region.
+       * The operand types must match the types of the `scf.while` operands. */
+      scf.yield %res : f32
+    }
+    ```
+
+    The custom syntax for this operation is as follows.
+
+    ```
+    op ::= `scf.while` initializer `:` function-type region `do` region
+           (`attributes` attribute-dict)?
+    initializer ::= /* empty */ | `(` assignment-list `)`
+    assignment-list ::= assignment | assignment `,` assignment-list
+    assignment ::= ssa-value `=` ssa-value
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inits);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+
+  let extraClassDeclaration = [{
+    OperandRange getSuccessorEntryOperands(unsigned index);
+  }];
+}
+
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
-                               ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> {
+                               ParentOneOf<["IfOp, ForOp", "ParallelOp",
+                                            "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
@@ -436,4 +582,5 @@
   // needed.
   let verifier = ?;
 }
+
 #endif // MLIR_DIALECT_SCF_SCFOPS
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -751,11 +751,20 @@
   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
   /// Parse a list of assignments of the form
-  /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
-  /// The list must contain at least one entry
-  virtual ParseResult
-  parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
-                      SmallVectorImpl<OperandType> &rhs) = 0;
+  /// (%x1 = %y1, %x2 = %y2, ...). Emits an error if there is no '('.
+  ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                                  SmallVectorImpl<OperandType> &rhs) {
+    OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
+    if (!result.hasValue())
+      return emitError(getCurrentLocation(), "expected '('");
+    return result.getValue();
+  }
+
+  /// Parse an optional list of assignments of the form
+  /// (%x1 = %y1, %x2 = %y2, ...). Returns llvm::None if there is no '('.
+  virtual OptionalParseResult
+  parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                              SmallVectorImpl<OperandType> &rhs) = 0;
 
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -140,26 +140,37 @@
   return RegionBranchOpInterface::verifyTypes(op);
 }
 
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values. Nothing is printed if there are no initializers.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
 static void print(OpAsmPrinter &p, ForOp op) {
-  bool printBlockTerminators = false;
   p << op.getOperationName() << " " << op.getInductionVar() << " = "
     << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
 
-  if (op.hasIterOperands()) {
-    p << " iter_args(";
-    auto regionArgs = op.getRegionIterArgs();
-    auto operands = op.getIterOperands();
-
-    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
-      p << std::get<0>(it) << " = " << std::get<1>(it);
-    });
-    p << ")";
-    p << " -> (" << op.getResultTypes() << ")";
-    printBlockTerminators = true;
-  }
+  printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
+                          " iter_args");
+  if (!op.getIterOperands().empty())
+    p << " -> (" << op.getIterOperands().getTypes() << ')';
   p.printRegion(op.region(),
                 /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
+                /*printBlockTerminators=*/op.hasIterOperands());
   p.printOptionalAttrDict(op.getAttrs());
 }
 
@@ -933,6 +944,158 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WhileOp
+//===----------------------------------------------------------------------===//
+
+OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index == 0 &&
+         "WhileOp is expected to branch only to the first region");
+
+  return inits();
+}
+
+void WhileOp::getSuccessorRegions(Optional<unsigned> index,
+                                  ArrayRef<Attribute> operands,
+                                  SmallVectorImpl<RegionSuccessor> &regions) {
+  (void)operands;
+
+  if (!index.hasValue()) {
+    regions.emplace_back(&before(), before().getArguments());
+    return;
+  }
+
+  assert(*index < 2 && "there are only two regions in a WhileOp");
+  if (*index == 0) {
+    regions.emplace_back(&after(), after().getArguments());
+    regions.emplace_back(getResults());
+    return;
+  }
+
+  regions.emplace_back(&before(), before().getArguments());
+}
+
+/// Parses a `while` op.
+///
+/// op ::= `scf.while` initializer `:` function-type region `do` region
+///        (`attributes` attribute-dict)?
+/// initializer ::= /* empty */ | `(` assignment-list `)`
+/// assignment-list ::= assignment | assignment `,` assignment-list
+/// assignment ::= ssa-value `=` ssa-value
+static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
+  Region *before = result.addRegion();
+  Region *after = result.addRegion();
+
+  OptionalParseResult listResult =
+      parser.parseOptionalAssignmentList(regionArgs, operands);
+  if (listResult.hasValue() && failed(listResult.getValue()))
+    return failure();
+
+  FunctionType functionType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseColonType(functionType)))
+    return failure();
+
+  result.addTypes(functionType.getResults());
+
+  if (functionType.getNumInputs() != operands.size()) {
+    return parser.emitError(typeLoc)
+           << "expected as many input types as operands "
+           << "(expected " << operands.size() << " got "
+           << functionType.getNumInputs() << ")";
+  }
+
+  // Resolve input operands.
+  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+                                    parser.getCurrentLocation(),
+                                    result.operands)))
+    return failure();
+
+  return failure(
+      parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
+      parser.parseKeyword("do") || parser.parseRegion(*after) ||
+      parser.parseOptionalAttrDictWithKeyword(result.attributes));
+}
+
+/// Prints a `while` op.
+static void print(OpAsmPrinter &p, scf::WhileOp op) {
+  p << op.getOperationName();
+  printInitializationList(p, op.before().front().getArguments(), op.inits(),
+                          " ");
+  p << " : ";
+  p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
+  p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
+  p << " do";
+  p.printRegion(op.after());
+  p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+/// Verifies that two ranges of types match, i.e. have the same number of
+/// entries and that types are pairwise equal. Reports errors on the given
+/// operation in case of mismatch.
+template <typename OpTy>
+static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
+                                           TypeRange right, StringRef message) {
+  if (left.size() != right.size())
+    return op.emitOpError("expects the same number of ") << message;
+
+  for (unsigned i = 0, e = left.size(); i < e; ++i) {
+    if (left[i] != right[i]) {
+      InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
+                                << message;
+      diag.attachNote() << "for argument " << i << ", found " << left[i]
+                        << " and " << right[i];
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+/// Returns the last op in the first block of `region` if it is a TerminatorTy;
+/// otherwise (including for an empty block) reports an error and returns null.
+template <typename TerminatorTy>
+static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
+                                           StringRef errorMessage) {
+  Block &block = region.front();
+  Operation *terminatorOperation = block.empty() ? nullptr : &block.back();
+  if (auto terminator = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
+    return terminator;
+  auto diag = op.emitOpError(errorMessage);
+  if (terminatorOperation)
+    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
+  return nullptr;
+}
+
+static LogicalResult verify(scf::WhileOp op) {
+  if (failed(RegionBranchOpInterface::verifyTypes(op)))
+    return failure();
+
+  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
+      op, op.before(),
+      "expects the 'before' region to terminate with 'scf.condition'");
+  if (!beforeTerminator)
+    return failure();
+
+  TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
+  if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
+                                   op.after().getArgumentTypes(),
+                                   "trailing operands of the 'before' block "
+                                   "terminator and 'after' region arguments")))
+    return failure();
+
+  if (failed(verifyTypeRangesMatch(
+          op, trailingTerminatorOperands, op.getResultTypes(),
+          "trailing operands of the 'before' block terminator and op results")))
+    return failure();
+
+  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
+      op, op.after(),
+      "expects the 'after' region to terminate with 'scf.yield'");
+  return success(afterTerminator != nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -76,10 +76,13 @@
 /// Verify that types match along all region control flow edges originating from
 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
 /// op). `getInputsTypesForRegion` is a function that returns the types of the
-/// inputs that flow from `sourceIndex' to the given region.
-static LogicalResult verifyTypesAlongAllEdges(
-    Operation *op, Optional<unsigned> sourceNo,
-    function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
+/// inputs that flow from `sourceNo` to the given region, or llvm::None if
+/// the exact type match verification is not necessary (e.g., if the Op verifies
+/// the match itself).
+static LogicalResult
+verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
+                         function_ref<Optional<TypeRange>(Optional<unsigned>)>
+                             getInputsTypesForRegion) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   SmallVector<RegionSuccessor, 2> successors;
@@ -113,17 +116,20 @@
       return diag;
     };
 
-    TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
+    Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+    if (!sourceTypes.hasValue())
+      continue;
+
     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
-    if (sourceTypes.size() != succInputsTypes.size()) {
+    if (sourceTypes->size() != succInputsTypes.size()) {
       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
-      return printEdgeName(diag) << ": source has " << sourceTypes.size()
+      return printEdgeName(diag) << ": source has " << sourceTypes->size()
                                  << " operands, but target successor needs "
                                  << succInputsTypes.size();
     }
 
     for (auto typesIdx :
-         llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
+         llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
       Type sourceType = std::get<0>(typesIdx.value());
       Type inputType = std::get<1>(typesIdx.value());
       if (sourceType != inputType) {
@@ -191,10 +197,15 @@
                               << " operands mismatch between return-like terminators";
     }
 
-    auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
+    auto inputTypesFromRegion =
+        [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
+      // If there is no return-like terminator, the op itself should verify
+      // type consistency.
+      if (!regionReturn)
+        return llvm::None;
+
       // All successors get the same set of operands.
-      return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
-                          : TypeRange();
+      return TypeRange(regionReturn->getOperands().getTypes());
     };
 
     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1465,10 +1465,13 @@
   }
 
   /// Parse a list of assignments of the form
-  /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
-  /// The list must contain at least one entry
-  ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
-                                  SmallVectorImpl<OperandType> &rhs) override {
+  /// (%x1 = %y1, %x2 = %y2, ...). Returns llvm::None if there is no '('.
+  OptionalParseResult
+  parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                              SmallVectorImpl<OperandType> &rhs) override {
+    if (failed(parseOptionalLParen()))
+      return llvm::None;
+
     auto parseElt = [&]() -> ParseResult {
       OperandType regionArg, operand;
       if (parseRegionArgument(regionArg) || parseEqual() ||
@@ -1478,8 +1481,6 @@
       rhs.push_back(operand);
       return success();
     };
-    if (parseLParen())
-      return failure();
     return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
   }
 
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -425,10 +425,88 @@
 }
 
 // -----
+
 func @yield_invalid_parent_op() {
   "my.op"() ({
-    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
+    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}}
     scf.yield
   }) : () -> ()
   return
 }
+
+// -----
+
+func @while_parser_type_mismatch() {
+  %true = constant true
+  // expected-error@+1 {{expected as many input types as operands (expected 0 got 1)}}
+  scf.while : (i32) -> () {
+    scf.condition(%true)
+  } do {
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_bad_before_terminator() {
+  // expected-error@+1 {{expects the 'before' region to terminate with 'scf.condition'}}
+  scf.while : () -> () {
+    // expected-note@+1 {{terminator here}}
+    "some.other_terminator"() : () -> ()
+  } do {
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_cross_region_arg_count_mismatch() {
+  %true = constant true
+  // expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and 'after' region arguments}}
+  scf.while : () -> () {
+    scf.condition(%true)
+  } do {
+  ^bb0(%arg0: i32):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_cross_region_arg_type_mismatch() {
+  %true = constant true
+  // expected-error@+2 {{expects the same types for trailing operands of the 'before' block terminator and 'after' region arguments}}
+  // expected-note@+1 {{for argument 0, found 'i1' and 'i32'}}
+  scf.while : () -> () {
+    scf.condition(%true) %true : i1
+  } do {
+  ^bb0(%arg0: i32):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_result_type_mismatch() {
+  %true = constant true
+  // expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and op results}}
+  scf.while : () -> () {
+    scf.condition(%true) %true : i1
+  } do {
+  ^bb0(%arg0: i1):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_bad_after_terminator() {
+  %true = constant true
+  // expected-error@+1 {{expects the 'after' region to terminate with 'scf.yield'}}
+  scf.while : () -> () {
+    scf.condition(%true)
+  } do {
+    // expected-note@+1 {{terminator here}}
+    "some.other_terminator"() : () -> ()
+  }
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -240,3 +240,42 @@
 // CHECK-NEXT: scf.yield %[[IFRES]] : f32
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[RESULT]]
+
+// CHECK-LABEL: @while
+func @while() {
+  %0 = "test.get_some_value"() : () -> i32
+  %1 = "test.get_some_value"() : () -> f32
+
+  // CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) {
+  %2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) {
+    %3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
+    %4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1
+    // CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}} : i64, f64
+    scf.condition(%4) %3#0, %3#1 : i64, f64
+  // CHECK: } do {
+  } do {
+    // CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64):
+  ^bb0(%arg2: i64, %arg3: f64):
+    %5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
+    // CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
+    scf.yield %5#0, %5#1 : i32, f32
+  // CHECK: attributes {foo = "bar"}
+  } attributes {foo="bar"}
+  return
+}
+
+// CHECK-LABEL: @infinite_while
+func @infinite_while() {
+  %true = constant true
+
+  // CHECK: scf.while : () -> () {
+  scf.while : () -> () {
+    // CHECK: scf.condition(%{{.*}})
+    scf.condition(%true)
+  // CHECK: } do {
+  } do {
+    // CHECK: scf.yield
+    scf.yield
+  }
+  return
+}