diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -55,9 +55,11 @@
     the terminator, so will the parsing even in cases when it is absent from the
     custom format. For example:
 
+    ```mlir
     loop.for %iv = %lb to %ub step %step {
       ... // body
     }
+    ```
   }];
   let arguments = (ins Index:$lowerBound, Index:$upperBound, Index:$step);
   let regions = (region SizedRegion<1>:$region);
@@ -88,18 +90,22 @@
     conditionally executing two regions of code. The operand to an if operation
     is a boolean value. The operation produces no results. For example:
 
+    ```mlir
     loop.if %b {
       ...
     } else {
       ...
     }
+    ```
 
     The 'else' block is optional, and may be omitted. For example:
 
+    ```mlir
     loop.if %b {
       ...
     }
+    ```
   }];
   let arguments = (ins I1:$condition);
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
@@ -124,8 +130,114 @@
   }];
 }
 
-def TerminatorOp :
-  Loop_Op<"terminator", [NativeOpTrait<"IsTerminator">]> {
+def ParallelOp : Loop_Op<"parallel",
+    [SameVariadicOperandSize, SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "parallel for operation";
+  let description = [{
+    The "loop.parallel" operation represents a loop nest taking 3 groups of SSA
+    values as operands that represent the lower bounds, upper bounds and steps,
+    respectively. The operation defines a variadic number of SSA values for its
+    induction variables. It has one region capturing the loop body. The
+    induction variables are represented as an argument of this region. These SSA
+    values always have type index, which is the size of the machine word. The
+    steps are values of type index, required to be positive.
+    The lower and upper bounds specify a half-open range: the range includes the
+    lower bound but does not include the upper bound.
+
+    Semantically we require that the iteration space can be iterated in any
+    order, and the loop body can be executed in parallel. If there are data
+    races, the behavior is undefined.
+
+    The parallel loop operation supports reduction of values produced by
+    individual iterations into a single result. This is modeled using the
+    loop.reduce operation (see loop.reduce for details). Each result of a
+    loop.parallel operation is associated with a reduce operation that is an
+    immediate child. Reduces are matched to result values in order of their
+    appearance in the body. Consequently, we require that the operation has as
+    many results as its body region has reduce operations.
+
+    The body region must contain exactly one block that terminates with
+    "loop.terminator". Parsing ParallelOp will create such a region and insert
+    the terminator when it is absent from the custom format. For example:
+
+    ```mlir
+    loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+      %zero = constant 0.0 : f32
+      loop.reduce(%zero) {
+        ^bb0(%lhs : f32, %rhs: f32):
+          %res = addf %lhs, %rhs : f32
+          loop.reduce.return %res : f32
+      } : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$lowerBound,
+                       Variadic<Index>:$upperBound,
+                       Variadic<Index>:$step);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
+def ReduceOp : Loop_Op<"reduce", [HasParent<"ParallelOp">]> {
+  let summary = "reduce operation for parallel for";
+  let description = [{
+    "loop.reduce" is an operation occurring inside "loop.parallel" operations.
+    It consists of one block with two arguments which have the same type as the
+    operand of "loop.reduce".
+
+    "loop.reduce" is used to model the value for reduction computations of a
+    "loop.parallel" operation. It has to appear as an immediate child of a
+    "loop.parallel" and is associated with a result value of its parent
+    operation.
+
+    Association is in the order of appearance in the body where the first result
+    of a parallel loop operation corresponds to the first "loop.reduce" in the
+    operation's body region. The reduce operation takes a single operand, which
+    is the value to be used in the reduction.
+
+    The reduce operation contains a region whose entry block expects two
+    arguments of the same type as the operand. As the iteration order of the
+    parallel loop and hence reduction order is unspecified, the result of
+    reduction may be non-deterministic unless the operation is associative and
+    commutative.
+
+    The result of the reduce operation's body must have the same type as the
+    operand and the associated result value of the parallel loop operation.
+    Example:
+
+    ```mlir
+    %zero = constant 0.0 : f32
+    loop.reduce(%zero) {
+      ^bb0(%lhs : f32, %rhs: f32):
+        %res = addf %lhs, %rhs : f32
+        loop.reduce.return %res : f32
+    } : f32
+    ```
+
+  }];
+
+  let arguments = (ins AnyType:$operand);
+  let regions = (region SizedRegion<1>:$reductionOperator);
+}
+
+def ReduceReturnOp :
+    Loop_Op<"reduce.return", [HasParent<"ReduceOp">, Terminator]> {
+  let summary = "terminator for reduce operation";
+  let description = [{
+    "loop.reduce.return" is a special terminator operation for the block inside
+    "loop.reduce". It terminates the region. Its operand must have the same type
+    as the operand of "loop.reduce". Example for the custom format:
+
+    ```mlir
+    loop.reduce.return %res : f32
+    ```
+  }];
+
+  let arguments = (ins AnyType:$result);
+}
+
+def TerminatorOp : Loop_Op<"terminator", [Terminator]> {
   let summary = "cf terminator operation";
   let description = [{
     "loop.terminator" is a special terminator operation for blocks inside
@@ -133,7 +245,9 @@
     syntax. However, `std` control operations omit the terminator in their custom
     syntax for brevity.
 
+    ```mlir
     loop.terminator
+    ```
   }];
 
   // No custom parsing/printing form.
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -185,13 +185,13 @@
     return failure();
 
   // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, {}, {}))
+  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
 
   // If we find an 'else' keyword then parse the 'else' region.
   if (!parser.parseOptionalKeyword("else")) {
-    if (parser.parseRegion(*elseRegion, {}, {}))
+    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
       return failure();
     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
   }
@@ -221,6 +221,203 @@
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ParallelOp op) {
+  // Check that there is at least one value in lowerBound, upperBound and step.
+  // It is sufficient to test only step, because it is ensured already that the
+  // number of elements in lowerBound, upperBound and step are the same.
+  Operation::operand_range stepValues = op.step();
+  if (stepValues.empty())
+    return op.emitOpError(
+        "needs at least one tuple element for lowerBound, upperBound and step");
+
+  // Check whether all constant step values are positive.
+  for (Value stepValue : stepValues)
+    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp()))
+      if (cst.getValue() <= 0)
+        return op.emitOpError("constant step operand must be positive");
+
+  // Check that the body defines the same number of block arguments as the
+  // number of tuple elements in step.
+  Block *body = &op.body().front();
+  if (body->getNumArguments() != stepValues.size())
+    return op.emitOpError(
+        "expects the same number of induction variables as bound and step "
+        "values");
+  for (auto arg : body->getArguments())
+    if (!arg.getType().isIndex())
+      return op.emitOpError(
+          "expects arguments for the induction variable to be of index type");
+
+  // Check that the number of results is the same as the number of ReduceOps.
+  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
+  if (op.results().size() != reductions.size())
+    return op.emitOpError(
+        "expects number of results to be the same as number of reductions");
+
+  // Check that the types of the results and reductions are the same.
+  for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
+    auto resultType = std::get<0>(resultAndReduce).getType();
+    auto reduceOp = std::get<1>(resultAndReduce);
+    auto reduceType = reduceOp.operand().getType();
+    if (resultType != reduceType)
+      return reduceOp.emitOpError()
+             << "expects type of reduce to be the same as result type: "
+             << resultType;
+  }
+  return success();
+}
+
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType, 4> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse step value.
+  SmallVector<OpAsmParser::OperandType, 4> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Now parse the body.
+  Region *body = result.addRegion();
+  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+  if (parser.parseRegion(*body, ivs, types))
+    return failure();
+
+  // Parse attributes and optional results (in case there is a reduce).
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOptionalColonTypeList(result.types))
+    return failure();
+
+  // Add a terminator if none was parsed (the custom form may omit it).
+  ParallelOp::ensureTerminator(*body, builder, result.location);
+
+  return success();
+}
+
+// The induction variables are printed in the op header, so the body region is
+// printed without its entry block arguments.
+static void print(OpAsmPrinter &p, ParallelOp op) {
+  p << op.getOperationName() << " (";
+  p.printOperands(op.body().front().getArguments());
+  p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
+    << op.step() << ")";
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op.getAttrs());
+  if (!op.results().empty())
+    p << " : " << op.getResultTypes();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReduceOp op) {
+  // The region of a ReduceOp has two arguments of the same type as its operand.
+  auto type = op.operand().getType();
+  Block &block = op.reductionOperator().front();
+  if (block.empty())
+    return op.emitOpError("the block inside reduce should not be empty");
+  if (block.getNumArguments() != 2 ||
+      llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
+        return arg.getType() != type;
+      }))
+    return op.emitOpError() << "expects two arguments to reduce block of type "
+                            << type;
+
+  // Check that the block is terminated by a ReduceReturnOp.
+  if (!isa<ReduceReturnOp>(block.getTerminator()))
+    return op.emitOpError("the block inside reduce should be terminated with a "
+                          "'loop.reduce.return' op");
+
+  return success();
+}
+
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+  // Parse an opening `(` followed by the reduced value followed by `)`
+  OpAsmParser::OperandType operand;
+  if (parser.parseLParen() || parser.parseOperand(operand) ||
+      parser.parseRParen())
+    return failure();
+
+  // Now parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // And the type of the operand (and also what reduce computes on).
+  Type resultType;
+  if (parser.parseColonType(resultType) ||
+      parser.resolveOperand(operand, resultType, result.operands))
+    return failure();
+
+  return success();
+}
+
+// Prints the custom form `loop.reduce(%operand) {...} : type`.
+static void print(OpAsmPrinter &p, ReduceOp op) {
+  p << op.getOperationName() << "(" << op.operand() << ") ";
+  p.printRegion(op.reductionOperator());
+  p << " : " << op.operand().getType();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceReturnOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReduceReturnOp op) {
+  // The type of the return value should be the same type as the type of the
+  // operand of the enclosing ReduceOp.
+  auto reduceOp = cast<ReduceOp>(op.getParentOp());
+  Type reduceType = reduceOp.operand().getType();
+  if (reduceType != op.result().getType())
+    return op.emitOpError() << "needs to have type " << reduceType
+                            << " (the type of the enclosing ReduceOp)";
+  return success();
+}
+
+static ParseResult parseReduceReturnOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType operand;
+  Type resultType;
+  if (parser.parseOperand(operand) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operand, resultType, result.operands))
+    return failure();
+
+  return success();
+}
+
+// Prints the custom form `loop.reduce.return %operand : type`.
+static void print(OpAsmPrinter &p, ReduceReturnOp op) {
+  p << op.getOperationName() << " " << op.result() << " : "
+    << op.result().getType();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -113,3 +113,183 @@
   return
 }
 
+// -----
+
+func @parallel_arguments_different_tuple_size(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{custom op 'loop.parallel' expected 1 operands}}
+  loop.parallel (%i0) = (%arg0) to (%arg1, %arg2) step () {
+  }
+  return
+}
+
+// -----
+
+func @parallel_body_arguments_wrong_type(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{'loop.parallel' op expects arguments for the induction variable to be of index type}}
+  "loop.parallel"(%arg0, %arg1, %arg2) ({
+    ^bb0(%i0: f32):
+      "loop.terminator"() : () -> ()
+  }): (index, index, index) -> ()
+  return
+}
+
+// -----
+
+func @parallel_body_wrong_number_of_arguments(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{'loop.parallel' op expects the same number of induction variables as bound and step values}}
+  "loop.parallel"(%arg0, %arg1, %arg2) ({
+    ^bb0(%i0: index, %i1: index):
+      "loop.terminator"() : () -> ()
+  }): (index, index, index) -> ()
+  return
+}
+
+// -----
+
+func @parallel_no_tuple_elements() {
+  // expected-error@+1 {{'loop.parallel' op needs at least one tuple element for lowerBound, upperBound and step}}
+  loop.parallel () = () to () step () {
+  }
+  return
+}
+
+// -----
+
+func @parallel_step_not_positive(
+    %arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+  // expected-error@+3 {{constant step operand must be positive}}
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%c1, %c0) {
+  }
+  return
+}
+
+// -----
+
+func @parallel_fewer_results_than_reduces(
+    %arg0 : index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{expects number of results to be the same as number of reductions}}
+  loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+    %c0 = constant 1.0 : f32
+    loop.reduce(%c0) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        loop.reduce.return %lhs : f32
+    } : f32
+  }
+  return
+}
+
+// -----
+
+func @parallel_more_results_than_reduces(
+    %arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expects number of results to be the same as number of reductions}}
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+  } : f32
+
+  return
+}
+
+// -----
+
+func @parallel_different_types_of_results_and_reduces(
+    %arg0 : index, %arg1: index, %arg2: index) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+    // expected-error@+1 {{expects type of reduce to be the same as result type: 'f32'}}
+    loop.reduce(%arg0) {
+      ^bb0(%lhs: index, %rhs: index):
+        loop.reduce.return %lhs : index
+    } : index
+  } : f32
+  return
+}
+
+// -----
+
+func @top_level_reduce(%arg0 : f32) {
+  // expected-error@+1 {{expects parent op 'loop.parallel'}}
+  loop.reduce(%arg0) {
+    ^bb0(%lhs : f32, %rhs : f32):
+      loop.reduce.return %lhs : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{the block inside reduce should not be empty}}
+    loop.reduce(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
+    loop.reduce(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32, %other : f32):
+        loop.reduce.return %lhs : f32
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
+    loop.reduce(%arg1) {
+      ^bb0(%lhs : f32, %rhs : i32):
+        loop.reduce.return %lhs : f32
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
+    loop.reduce(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+        "loop.terminator"(): () -> ()
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    loop.reduce(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+        %c0 = constant 1 : index
+        // expected-error@+1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
+        loop.reduce.return %c0 : index
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @reduceReturn_not_inside_reduce(%arg0 : f32) {
+  "foo.region"() ({
+    // expected-error@+1 {{expects parent op 'loop.reduce'}}
+    loop.reduce.return %arg0 : f32
+  }): () -> ()
+  return
+}
diff --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -49,3 +49,47 @@
 // CHECK-NEXT:     %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 // CHECK-NEXT:   } else {
 // CHECK-NEXT:     %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+
+func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                        %arg3 : index, %arg4 : index) {
+  %step = constant 1 : index
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                                          step (%arg4, %step) {
+    %min_cmp = cmpi "slt", %i0, %i1 : index
+    %min = select %min_cmp, %i0, %i1 : index
+    %max_cmp = cmpi "sge", %i0, %i1 : index
+    %max = select %max_cmp, %i0, %i1 : index
+    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) {
+      %zero = constant 0.0 : f32
+      loop.reduce(%zero) {
+        ^bb0(%lhs : f32, %rhs: f32):
+          %res = addf %lhs, %rhs : f32
+          loop.reduce.return %res : f32
+      } : f32
+    } : f32
+  }
+  return
+}
+// CHECK-LABEL: func @std_parallel_loop(
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]:
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]:
+//  CHECK-SAME:   %[[ARG2:[A-Za-z0-9]+]]:
+//  CHECK-SAME:   %[[ARG3:[A-Za-z0-9]+]]:
+//  CHECK-SAME:   %[[ARG4:[A-Za-z0-9]+]]:
+//       CHECK:   %[[STEP:.*]] = constant 1 : index
+//  CHECK-NEXT:   loop.parallel (%[[I0:.*]], %[[I1:.*]]) = (%[[ARG0]], %[[ARG1]]) to
+//       CHECK:   (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[STEP]]) {
+//  CHECK-NEXT:     %[[MIN_CMP:.*]] = cmpi "slt", %[[I0]], %[[I1]] : index
+//  CHECK-NEXT:     %[[MIN:.*]] = select %[[MIN_CMP]], %[[I0]], %[[I1]] : index
+//  CHECK-NEXT:     %[[MAX_CMP:.*]] = cmpi "sge", %[[I0]], %[[I1]] : index
+//  CHECK-NEXT:     %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
+//  CHECK-NEXT:     loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]]) step (%[[I1]]) {
+//  CHECK-NEXT:       %[[ZERO:.*]] = constant 0.000000e+00 : f32
+//  CHECK-NEXT:       loop.reduce(%[[ZERO]]) {
+//  CHECK-NEXT:       ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//  CHECK-NEXT:         %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
+//  CHECK-NEXT:         loop.reduce.return %[[RES]] : f32
+//  CHECK-NEXT:       } : f32
+//  CHECK-NEXT:       "loop.terminator"() : () -> ()
+//  CHECK-NEXT:     } : f32
+//  CHECK-NEXT:     "loop.terminator"() : () -> ()