diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -413,8 +413,93 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } +def WhileOp : SCF_Op<"while", [RecursiveSideEffects]> { + let summary = "a generic 'while' loop"; + let description = [{ + This operation represents a generic "while"/"do-while" loop that keeps + iterating as long as a condition is satisfied. There is no restriction on + the complexity of the condition. It consists of two regions (with single + block each): "before" region and "after" region. The names of regions + indicate whether they execute before or after the condition check. + Therefore, if the main loop payload is located in the "before" region, the + operation is a "do-while" loop. Otherwise, it is a "while" loop. + + The terminator of the "before" region accepts as its first operand an `i1` + value indicating whether to proceed to the "after" region (value is `true`) + or not. Two regions communicate by means of region arguments. Initially, the + "before" region accepts as arguments the operands of the `scf.while` + operation and uses them to evaluate the condition. It uses the `scf.yield` + terminator with the condition value, followed by a list of values that are + either forwarded to the "after" region if the control flow is transferred + there or returned as results of the `scf.while` operation otherwise. The + "after" region takes as arguments the values yielded by the "before" region + and uses `scf.yield` to supply new arguments for the "before" region, into + which it transfers the control flow unconditionally. + + A simple "while" loop can be represented as follows. + + ```mlir + %res = scf.while (%arg1 = %init1) : (f32) -> f32 { + /* "Before" region. */ + %condition = call @evaluate_condition(%arg1) : (f32) -> i1 + + /* Forward the argument (as result or "after" region argument).
*/ + scf.yield %condition, %arg1 : i1, f32 + + } do { + ^bb0(%arg2: f32): + %next = call @payload(%arg2) : (f32) -> f32 + + /* Forward the new value to the "before" region. */ + scf.yield %next : f32 + } + ``` + + A simple "do-while" loop can be represented by reducing the "after" block + to a simple forwarder. + + ```mlir + %res = scf.while (%arg1 = %init1) : (f32) -> f32 { + /* First, compute the payload. */ + %next = call @payload(%arg1) : (f32) -> f32 + + /* Then, evaluate the condition. */ + %condition = call @evaluate_condition(%arg1) : (f32) -> i1 + + /* Loop through the "after" region. */ + scf.yield %condition, %next : i1, f32 + + } do { + ^bb0(%arg2: f32): + /* Just forward the values back to "before" region unmodified. */ + scf.yield %arg2 : f32 + } + ``` + + Note that the types of region arguments need not match. The op expects + the operand types to match with argument types of the "before" region; the + result types to match with the trailing operand types of the terminator of + the "before" region, and with the argument types of the "after" region. + + The custom syntax for this operation is as follows. + + ``` + op ::= `scf.while` initializer `:` function-type region `do` region + `attributes` attribute-dict + initializer ::= /* empty */ | `(` assignment-list `)` + assignment-list ::= assignment | assignment `,` assignment-list + assignment ::= ssa-value `=` ssa-value + ``` + }]; + + let arguments = (ins Variadic:$inits); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after); +} + def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator, - ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> { + ParentOneOf<["IfOp, ForOp", "ParallelOp", + "WhileOp"]>]> { let summary = "loop yield and termination operation"; let description = [{ "scf.yield" yields an SSA value from the SCF dialect op region and @@ -436,4 +521,5 @@ // needed.
let verifier = ?; } + #endif // MLIR_DIALECT_SCF_SCFOPS diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -751,11 +751,18 @@ parseOptionalColonTypeList(SmallVectorImpl &result) = 0; /// Parse a list of assignments of the form - /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). - /// The list must contain at least one entry - virtual ParseResult - parseAssignmentList(SmallVectorImpl &lhs, - SmallVectorImpl &rhs) = 0; + /// (%x1 = %y1, %x2 = %y2, ...) + ParseResult parseAssignmentList(SmallVectorImpl &lhs, + SmallVectorImpl &rhs) { + OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); + if (!result.hasValue()) + return emitError(getCurrentLocation(), "expected '('"); + return result.getValue(); + } + + virtual OptionalParseResult + parseOptionalAssignmentList(SmallVectorImpl &lhs, + SmallVectorImpl &rhs) = 0; /// Parse a keyword followed by a type. 
ParseResult parseKeywordType(const char *keyword, Type &result) { diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -140,26 +140,33 @@ return RegionBranchOpInterface::verifyTypes(op); } +static void printInitializationList(OpAsmPrinter &p, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + p << prefix << '('; + llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { + p << std::get<0>(it) << " = " << std::get<1>(it); + }); + p << ")"; +} + static void print(OpAsmPrinter &p, ForOp op) { - bool printBlockTerminators = false; p << op.getOperationName() << " " << op.getInductionVar() << " = " << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); - if (op.hasIterOperands()) { - p << " iter_args("; - auto regionArgs = op.getRegionIterArgs(); - auto operands = op.getIterOperands(); - - llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { - p << std::get<0>(it) << " = " << std::get<1>(it); - }); - p << ")"; - p << " -> (" << op.getResultTypes() << ")"; - printBlockTerminators = true; - } + printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(), + " iter_args"); + if (!op.getIterOperands().empty()) + p << " -> (" << op.getIterOperands().getTypes() << ')'; p.printRegion(op.region(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + /*printBlockTerminators=*/op.hasIterOperands()); p.printOptionalAttrDict(op.getAttrs()); } @@ -933,6 +940,131 @@ return success(); } +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +static ParseResult parseWhileOp(OpAsmParser 
&parser, OperationState &result) { + SmallVector regionArgs, operands; + Region *before = result.addRegion(); + Region *after = result.addRegion(); + + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.hasValue() && failed(listResult.getValue())) + return failure(); + + FunctionType functionType; + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + if (failed(parser.parseColonType(functionType))) + return failure(); + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(typeLoc) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + for (auto pair : llvm::zip(operands, functionType.getInputs())) + if (parser.resolveOperand(std::get<0>(pair), std::get<1>(pair), + result.operands)) + return failure(); + + return failure( + parser.parseRegion(*before, regionArgs, functionType.getInputs()) || + parser.parseKeyword("do") || parser.parseRegion(*after) || + parser.parseOptionalAttrDictWithKeyword(result.attributes)); +} + +static void print(OpAsmPrinter &p, scf::WhileOp op) { + p << op.getOperationName(); + printInitializationList(p, op.before().front().getArguments(), op.inits(), + " "); + p << " : "; + p.printFunctionalType(op.inits().getTypes(), op.results().getTypes()); + p.printRegion(op.before(), /*printEntryBlockArgs=*/false); + p << " do"; + p.printRegion(op.after()); + p.printOptionalAttrDictWithKeyword(op.getAttrs()); +} + +template +static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, + TypeRange right, StringRef message) { + if (left.size() != right.size()) + return op.emitOpError("expects the same number of ") << message; + + for (unsigned i = 0, e = left.size(); i < e; ++i) + if (left[i] != right[i]) { + InFlightDiagnostic diag = op.emitOpError("expects the same types for ") + << message; 
+ diag.attachNote() << "for argument " << i << ", found " << left[i] + << " and " << right[i]; + return diag; + } + + return success(); +} + +static scf::YieldOp verifyAndGetYieldTerminator(scf::WhileOp op, Region ®ion, + StringRef errorMessage) { + Operation *terminatorOperation = region.front().getTerminator(); + if (auto yield = dyn_cast_or_null(terminatorOperation)) + return yield; + + auto diag = op.emitOpError(errorMessage); + if (terminatorOperation) + diag.attachNote(terminatorOperation->getLoc()) << "terminator here"; + return nullptr; +} + +static LogicalResult verify(scf::WhileOp op) { + if (failed(verifyTypeRangesMatch( + op, op.inits().getTypes(), op.before().getArgumentTypes(), + "operands and arguments of the first region"))) + return failure(); + + scf::YieldOp beforeTerminator = verifyAndGetYieldTerminator( + op, op.before(), + "expects the first region to terminate with 'scf.yield'"); + if (!beforeTerminator) + return failure(); + + if (beforeTerminator.getNumOperands() == 0 || + !beforeTerminator.results()[0].getType().isInteger(1)) + return op.emitOpError("expects the leading operand of the first region " + "terminator to be of 'i1' type") + .attachNote(beforeTerminator.getLoc()) + << "terminator here"; + + TypeRange trailingTerminatorOperands = + TypeRange(beforeTerminator.results().getTypes()).drop_front(); + if (failed(verifyTypeRangesMatch( + op, trailingTerminatorOperands, op.after().getArgumentTypes(), + "trailing operands of the first block terminator and second block " + "arguments"))) + return failure(); + + if (failed(verifyTypeRangesMatch( + op, trailingTerminatorOperands, op.getResultTypes(), + "trailing operands of the first block terminator and op results"))) + return failure(); + + scf::YieldOp afterTerminator = verifyAndGetYieldTerminator( + op, op.after(), + "expects the second region to terminate with 'scf.yield'"); + if (!afterTerminator) + return failure(); + + return verifyTypeRangesMatch( + op, 
afterTerminator.results().getTypes(), op.before().getArgumentTypes(), + "operands of the second block terminator and first block arguments"); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1465,10 +1465,13 @@ } /// Parse a list of assignments of the form - /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). - /// The list must contain at least one entry - ParseResult parseAssignmentList(SmallVectorImpl &lhs, - SmallVectorImpl &rhs) override { + /// (%x1 = %y1, %x2 = %y2, ...). + OptionalParseResult + parseOptionalAssignmentList(SmallVectorImpl &lhs, + SmallVectorImpl &rhs) override { + if (failed(parseOptionalLParen())) + return llvm::None; + auto parseElt = [&]() -> ParseResult { OperandType regionArg, operand; if (parseRegionArgument(regionArg) || parseEqual() || @@ -1478,8 +1481,6 @@ rhs.push_back(operand); return success(); }; - if (parseLParen()) - return failure(); return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); } diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -425,10 +425,143 @@ } // ----- + func @yield_invalid_parent_op() { "my.op"() ({ - // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}} + // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}} scf.yield }) : () -> () return } + +// ----- + +func @while_parser_type_mismatch() { + %true = constant true + // expected-error@+1 {{expected as many input types as operands (expected 0 got 1)}} + scf.while : (i32) -> () { + scf.yield %true : i1 + } do { + scf.yield + } +} + +// ----- + +func 
@while_arg_mismatch() { + %0 = "test.get_some_value"() : () -> i32 + %true = constant true + + // expected-error@+1 {{expects the same number of operands and arguments of the first region}} + "scf.while"(%0) ({ + ^bb0: + scf.yield %true : i1 + }, { + scf.yield + }) : (i32) -> () +} + +// ----- + +func @while_arg_mismatch() { + %0 = "test.get_some_value"() : () -> i32 + %true = constant true + + // expected-error@+2 {{expects the same types for operands and arguments of the first region}} + // expected-note@+1 {{for argument 0, found 'i32' and 'i64'}} + "scf.while"(%0) ({ + ^bb0(%arg0: i64): + scf.yield %true : i1 + }, { + scf.yield + }) : (i32) -> () +} + +// ----- + +func @while_bad_terminator() { + // expected-error@+1 {{expects the first region to terminate with 'scf.yield'}} + scf.while : () -> () { + // expected-note@+1 {{terminator here}} + "some.other_terminator"() : () -> () + } do { + scf.yield + } +} + +// ----- + +func @while_bad_terminator_arg() { + // expected-error@+1 {{expects the leading operand of the first region terminator to be of 'i1' type}} + scf.while : () -> () { + // expected-note@+1 {{terminator here}} + scf.yield + } do { + scf.yield + } +} + +// ----- + +func @while_cross_region_type_mismatch() { + %true = constant true + // expected-error@+1 {{expects the same number of trailing operands of the first block terminator and second block arguments}} + scf.while : () -> () { + scf.yield %true : i1 + } do { + ^bb0(%arg0: i32): + scf.yield + } +} + +// ----- + +func @while_cross_region_type_mismatch() { + %true = constant true + // expected-error@+2 {{expects the same types for trailing operands of the first block terminator and second block arguments}} + // expected-note@+1 {{for argument 0, found 'i1' and 'i32}} + scf.while : () -> () { + scf.yield %true, %true : i1, i1 + } do { + ^bb0(%arg0: i32): + scf.yield + } +} + +// ----- + +func @while_result_type_mismatch() { + %true = constant true + // expected-error@+1 {{expects the same number 
of trailing operands of the first block terminator and op results}} + scf.while : () -> () { + scf.yield %true, %true : i1, i1 + } do { + ^bb0(%arg0: i1): + scf.yield + } +} + +// ----- + +func @while_bad_terminator() { + %true = constant true + // expected-error@+1 {{expects the second region to terminate with 'scf.yield'}} + scf.while : () -> () { + scf.yield %true : i1 + } do { + // expected-note@+1 {{terminator here}} + "some.other_terminator"() : () -> () + } +} + +// ----- + +func @while_cross_region_type_mismatch() { + %true = constant true + // expected-error@+1 {{expects the same number of operands of the second block terminator and first block arguments}} + scf.while : () -> () { + scf.yield %true : i1 + } do { + scf.yield %true : i1 + } +} diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -240,3 +240,42 @@ // CHECK-NEXT: scf.yield %[[IFRES]] : f32 // CHECK-NEXT: } // CHECK-NEXT: return %[[RESULT]] + +// CHECK-LABEL: @while +func @while() { + %0 = "test.get_some_value"() : () -> i32 + %1 = "test.get_some_value"() : () -> f32 + + // CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) { + %2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) { + %3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64) + %4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1 + // CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : i1, i64, f64 + scf.yield %4, %3#0, %3#1 : i1, i64, f64 + // CHECK: } do { + } do { + // CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64): + ^bb0(%arg2: i64, %arg3: f64): + %5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32) + // CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32 + scf.yield %5#0, %5#1 : i32, f32 + // CHECK: attributes {foo = "bar"} + } attributes {foo="bar"} + return +} + +// CHECK-LABEL: @infinite_while +func @infinite_while() { + %true = constant true + + // CHECK: 
scf.while : () -> () { + scf.while : () -> () { + // CHECK: scf.yield %{{.*}} : i1 + scf.yield %true : i1 + // CHECK: } do { + } do { + // CHECK: scf.yield + scf.yield + } + return +}