diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -121,50 +121,56 @@ [AutomaticAllocationScope, DeclareOpInterfaceMethods, + AllTypesMatch<["lowerBound", "upperBound", "step"]>, ConditionallySpeculatable, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects]> { let summary = "for operation"; let description = [{ - The "scf.for" operation represents a loop taking 3 SSA value as operands + The `scf.for` operation represents a loop taking 3 SSA values as operands that represent the lower bound, upper bound and step respectively. The operation defines an SSA value for its induction variable. It has one region capturing the loop body. The induction variable is represented as an - argument of this region. This SSA value always has type index, which is the - size of the machine word. The step is a value of type index, required to be - positive. - The lower and upper bounds specify a half-open range: the range includes - the lower bound but does not include the upper bound. + argument of this region. This SSA value has signless integer or index type. + The bounds and step have the same type, and the step must be positive. The lower and + upper bounds specify a half-open range: the range includes the lower bound + but does not include the upper bound. The body region must contain exactly one block that terminates with - "scf.yield". Calling ForOp::build will create such a region and insert + `scf.yield`. Calling ForOp::build will create such a region and insert the terminator implicitly if none is defined, so will the parsing even in cases when it is absent from the custom format. For example: ```mlir + // Index case. scf.for %iv = %lb to %ub step %step { ... // body } + ... + // Integer case. + scf.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 { + ...
// body + } ``` `scf.for` can also operate on loop-carried variables and returns the final values after loop termination. The initial values of the variables are - passed as additional SSA operands to the "scf.for" following the 3 loop + passed as additional SSA operands to the `scf.for` following the 3 loop control SSA values mentioned above (lower bound, upper bound and step). The operation region has an argument for the induction variable, followed by one argument for each loop-carried variable, representing the value of the variable at the current iteration. - The region must terminate with a "scf.yield" that passes the current + The region must terminate with a `scf.yield` that passes the current values of all loop-carried variables to the next iteration, or to the - "scf.for" result, if at the last iteration. The static type of a + `scf.for` result, if at the last iteration. The static type of a loop-carried variable may not change with iterations; its runtime type is allowed to change. Note, that when the loop-carried variables are present, calling ForOp::build will not insert the terminator implicitly. The caller - must insert "scf.yield" in that case. + must insert `scf.yield` in that case. - "scf.for" results hold the final values after the last iteration. + `scf.for` results hold the final values after the last iteration. For example, to sum-reduce a memref: ```mlir @@ -185,11 +191,11 @@ } ``` - If the "scf.for" defines any values, a yield must be explicitly present. - The number and types of the "scf.for" results must match the initial - values in the "iter_args" binding and the yield operands. + If the `scf.for` defines any values, a yield must be explicitly present. + The number and types of the `scf.for` results must match the initial + values in the `iter_args` binding and the yield operands. 
- Another example with a nested "scf.if" (see "scf.if" for details) to + Another example with a nested `scf.if` (see `scf.if` for details) to perform conditional reduction: ```mlir @@ -213,9 +219,9 @@ } ``` }]; - let arguments = (ins Index:$lowerBound, - Index:$upperBound, - Index:$step, + let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound, + AnySignlessIntegerOrIndex:$upperBound, + AnySignlessIntegerOrIndex:$step, Variadic:$initArgs); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -297,10 +297,11 @@ result.addOperands(iterArgs); for (Value v : iterArgs) result.addTypes(v.getType()); + Type t = lb.getType(); Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(builder.getIndexType(), result.location); + bodyBlock.addArgument(t, result.location); for (Value v : iterArgs) bodyBlock.addArgument(v.getType(), v.getLoc()); @@ -337,11 +338,9 @@ LogicalResult ForOp::verifyRegions() { - // Check that the body defines as single block argument for the induction - // variable. + // Check that the induction variable has the same type as the bounds and + // step.
- auto *body = getBody(); - if (!body->getArgument(0).getType().isIndex()) + if (getInductionVar().getType() != getLowerBound().getType()) return emitOpError( - "expected body first argument to be an index argument for " - "the induction variable"); + "expected induction variable to be same type as bounds and step"); auto opNumResults = getNumResults(); if (opNumResults == 0) @@ -363,7 +362,7 @@ return emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; - i++; + ++i; } return success(); } @@ -413,6 +412,10 @@ if (!getIterOperands().empty()) p << " -> (" << getIterOperands().getTypes() << ')'; p << ' '; + // FIXME: ForOp::parse expects `: type` right after `step`, before `iter_args`; + // a non-index loop that has iter_args does not round-trip. + if (Type t = getInductionVar().getType(); !t.isIndex()) + p << ": " << t << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/hasIterOperands()); @@ -421,21 +424,27 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); - Type indexType = builder.getIndexType(); + Type type; OpAsmParser::Argument inductionVariable; - inductionVariable.type = indexType; OpAsmParser::UnresolvedOperand lb, ub, step; // Parse the induction variable followed by '='. - if (parser.parseArgument(inductionVariable) || parser.parseEqual() || + if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || // Parse loop bounds. - parser.parseOperand(lb) || - parser.resolveOperand(lb, indexType, result.operands) || - parser.parseKeyword("to") || parser.parseOperand(ub) || - parser.resolveOperand(ub, indexType, result.operands) || - parser.parseKeyword("step") || parser.parseOperand(step) || - parser.resolveOperand(step, indexType, result.operands)) + parser.parseOperand(lb) || parser.parseKeyword("to") || + parser.parseOperand(ub) || parser.parseKeyword("step") || + parser.parseOperand(step)) + return failure(); + // Parse optional type, else assume Index.
+ if (parser.parseOptionalColon()) + type = builder.getIndexType(); + else if (parser.parseType(type)) + return failure(); + inductionVariable.type = type; + if (parser.resolveOperand(lb, type, result.operands) || + parser.resolveOperand(ub, type, result.operands) || + parser.resolveOperand(step, type, result.operands)) return failure(); // Parse the optional initial iteration arguments. diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics func.func @loop_for_lb(%arg0: f32, %arg1: index) { - // expected-error@+1 {{operand #0 must be index}} + // expected-error@+1 {{operand #0 must be signless integer or index}} "scf.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> () return } @@ -9,7 +9,7 @@ func.func @loop_for_ub(%arg0: f32, %arg1: index) { - // expected-error@+1 {{operand #1 must be index}} + // expected-error@+1 {{operand #1 must be signless integer or index}} "scf.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> () return } @@ -17,13 +17,21 @@ // ----- func.func @loop_for_step(%arg0: f32, %arg1: index) { - // expected-error@+1 {{operand #2 must be index}} + // expected-error@+1 {{operand #2 must be signless integer or index}} "scf.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> () return } // ----- +func.func @loop_for_mismatch(%arg0: i32, %arg1: index) { + // expected-error@+1 {{all of {lowerBound, upperBound, step} have same type}} + "scf.for"(%arg1, %arg0, %arg1) ({}) : (index, i32, index) -> () + return +} + +// ----- + func.func @loop_for_step_positive(%arg0: index) { // expected-error@+2 {{constant step operand must be positive}} %c0 = arith.constant 0 : index @@ -63,7 +71,7 @@ // ----- func.func @loop_for_single_index_argument(%arg0: index) { - // expected-error@+1 {{op expected body first argument to be an
index argument for the induction variable}} + // expected-error@+1 {{expected induction variable to be same type as bounds and step}} "scf.for"(%arg0, %arg0, %arg0) ( { ^bb0(%i0 : f32): diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -26,6 +26,17 @@ // CHECK-NEXT: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) { + scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 { + scf.for %i1 = %arg0 to %arg1 step %arg2 : i32 { + } + } + return +} +// CHECK-LABEL: func @std_for_i32( +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {