diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -124,8 +124,63 @@
   }];
 }
 
-def TerminatorOp :
-    Loop_Op<"terminator", [NativeOpTrait<"IsTerminator">]> {
+def ParallelOp : Loop_Op<"parallel",
+    [SameVariadicOperandSize, SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "parallel for operation";
+  let description = [{
+    The "loop.parallel" operation represents a loop nest taking 3 variadic SSA
+    values as operands that represent the lower bound, upper bound and step,
+    respectively. The operation defines a variadic number of SSA values for its
+    induction variables. It has one region capturing the loop body. The
+    induction variables are represented as arguments of this region. These SSA
+    values always have type index, which is the size of the machine word. The
+    steps are values of type index, required to be positive.
+    The lower and upper bounds specify a half-open range: the range includes the
+    lower bound but does not include the upper bound.
+
+    The body region must contain exactly one block that terminates with
+    "loop.terminator". Parsing ParallelOp will create such a region and insert
+    the terminator when it is absent from the custom format. For example:
+
+    loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+      ... // body
+    }
+  }];
+
+  let arguments = (ins Variadic<Index>:$lowerBound,
+                       Variadic<Index>:$upperBound,
+                       Variadic<Index>:$step);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
+def YieldOp : Loop_Op<"yield", [HasParent<"ParallelOp">]> {
+  let summary = "yield operation for parallel for";
+  let description = [{
+    "loop.yield" is an operation occurring inside "loop.parallel" operations. It
+    consists of one block with two arguments which have the same type as the
+    operand of "loop.yield".
+  }];
+
+  let arguments = (ins AnyType:$operand);
+  let regions = (region SizedRegion<1>:$operation);
+}
+
+def YieldReturnOp :
+    Loop_Op<"yield.return", [HasParent<"YieldOp">, Terminator]> {
+  let summary = "terminator for yield operation";
+  let description = [{
+    "loop.yield.return" is a special terminator operation for the block inside
+    "loop.yield". It terminates the region. It should have the same type as the
+    operand of "loop.yield". Example for the custom format:
+
+    loop.yield.return %res : f32
+  }];
+
+  let arguments = (ins AnyType:$result);
+}
+
+def TerminatorOp : Loop_Op<"terminator", [Terminator]> {
   let summary = "cf terminator operation";
   let description = [{
     "loop.terminator" is a special terminator operation for blocks inside
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -186,13 +186,13 @@
     return failure();
 
   // Parse the 'then' region.
-  if (parser.parseRegion(*thenRegion, {}, {}))
+  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
 
   // If we find an 'else' keyword then parse the 'else' region.
   if (!parser.parseOptionalKeyword("else")) {
-    if (parser.parseRegion(*elseRegion, {}, {}))
+    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
       return failure();
     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
   }
@@ -222,6 +222,207 @@
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ParallelOp op) {
+  // Check whether all constant step values are positive.
+  Operation::operand_range stepValues = op.step();
+  for (const auto &stepValue : stepValues)
+    if (auto cst =
+            dyn_cast_or_null<ConstantIndexOp>(stepValue->getDefiningOp()))
+      if (cst.getValue() <= 0)
+        return op.emitOpError("constant step operand must be positive");
+
+  // Check that there is at least one value in lowerBound, upperBound and step.
+  // It is sufficient to test only step, because it is ensured already that the
+  // number of elements in lowerBound, upperBound and step are the same.
+  if (stepValues.empty())
+    return op.emitOpError(
+        "need at least one tuple element for lowerBound, upperBound and step");
+
+  // Check that the body defines the same number of block arguments as the
+  // number of tuple elements in step.
+  auto *body = &op.body().front();
+  if (body->getNumArguments() != stepValues.size())
+    return op.emitOpError(
+        "expected the same number of induction variables as bound and step "
+        "values");
+  for (auto &arg : body->getArguments())
+    if (!arg->getType().isIndex())
+      return op.emitOpError(
+          "expected arguments for the induction variable to be of index type");
+
+  // Check that the number of results is the same as the number of YieldOps.
+  SmallVector<YieldOp, 4> yields(body->getOps<YieldOp>());
+  if (op.results().size() != yields.size())
+    return op.emitOpError(
+        "expected number of results to be the same as number of yields");
+
+  // Check that the types of the results and yields are the same.
+  for (auto resultAndYield : llvm::zip(op.results(), yields)) {
+    auto resultType = std::get<0>(resultAndYield)->getType();
+    auto yieldOp = std::get<1>(resultAndYield);
+    auto yieldType = yieldOp.operand()->getType();
+    if (resultType != yieldType)
+      return yieldOp.emitOpError()
+             << "expected type of yield to be the same as result type: "
+             << resultType;
+  }
+  return success();
+}
+
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType, 4> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse step value.
+  SmallVector<OpAsmParser::OperandType, 4> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Now parse the body.
+  Region *body = result.addRegion();
+  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+  if (parser.parseRegion(*body, ivs, types))
+    return failure();
+
+  // Parse attributes and optional results (in case there is a yield).
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOptionalColonTypeList(result.types))
+    return failure();
+
+  // Add a terminator if none was parsed.
+  ParallelOp::ensureTerminator(*body, builder, result.location);
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, ParallelOp op) {
+  p << op.getOperationName() << " (";
+  p.printOperands(op.body().front().getArguments());
+  p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
+    << op.step() << ")";
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op.getAttrs());
+  if (!op.results().empty()) {
+    p << " : ";
+    interleaveComma(op.getResultTypes(), p);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(YieldOp op) {
+  // The region of a YieldOp has two arguments of the same type as its operand.
+  auto type = op.operand()->getType();
+  auto &block = op.operation().front();
+  if (block.getNumArguments() != 2 ||
+      llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
+        return arg.getType() != type;
+      }))
+    return op.emitOpError() << "expected two arguments to yield block of type "
+                            << type;
+
+  // Guard against an empty block before inspecting its last operation.
+  if (block.empty())
+    return op.emitOpError("the block inside yield should not be empty");
+
+  // Check that the block is terminated by a YieldReturnOp.
+  if (!isa<YieldReturnOp>(&block.back()))
+    return op.emitOpError("the block inside yield should be terminated with a "
+                          "'loop.yield.return' op");
+
+  return success();
+}
+
+static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
+
+  // Parse an opening `(` followed by the yielded value followed by `)`
+  OpAsmParser::OperandType operand;
+  if (parser.parseLParen() || parser.parseOperand(operand) ||
+      parser.parseRParen())
+    return failure();
+
+  // Now parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // And the type of the operand (and also what yield computes on).
+  Type resultType;
+  if (parser.parseColonType(resultType) ||
+      parser.resolveOperand(operand, resultType, result.operands))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, YieldOp op) {
+  p << op.getOperationName() << "(" << op.operand() << ") ";
+  p.printRegion(op.operation());
+  p << " : " << op.operand()->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// YieldReturnOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(YieldReturnOp op) {
+  // The type of the return value should be the same type as the type of the
+  // operand of the enclosing YieldOp.
+  auto yieldOp = cast<YieldOp>(op.getParentOp());
+  Type yieldType = yieldOp.operand()->getType();
+  if (yieldType != op.result()->getType())
+    return op.emitOpError() << "YieldReturnOp needs to have type " << yieldType
+                            << " (the type of the enclosing YieldOp)";
+  return success();
+}
+
+static ParseResult parseYieldReturnOp(OpAsmParser &parser,
+                                      OperationState &result) {
+  OpAsmParser::OperandType operand;
+  Type resultType;
+  if (parser.parseOperand(operand) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operand, resultType, result.operands))
+    return failure();
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, YieldReturnOp op) {
+  p << op.getOperationName() << " ";
+  p.printOperand(op.result());
+  p << " : ";
+  p.printType(op.result()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -113,3 +113,225 @@
 
   return
 }
+
+// -----
+
+func @parallel_arguments_different_tuple_size(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{custom op 'loop.parallel' expected 1 operands}}
+  loop.parallel (%i0) = (%arg0) to (%arg1, %arg2) step () {
+  }
+  return
+}
+
+// -----
+
+func @parallel_arguments_wrong_lower_bound_type(%arg0: index, %arg1: index) {
+  %lb = constant 1.0 : f32 // expected-note {{prior use here}}
+  // expected-error@+1 {{use of value '%lb' expects different type than prior uses: 'index' vs 'f32'}}
+  loop.parallel (%i0) = (%lb) to (%arg0) step (%arg1) {
+  }
+  return
+}
+
+// -----
+
+func @parallel_arguments_wrong_upper_bound_type(%arg0: index, %arg1: index) {
+  %ub = constant 1.0 : f32 // expected-note {{prior use here}}
+  // expected-error@+1 {{use of value '%ub' expects different type than prior uses: 'index' vs 'f32'}}
+  loop.parallel (%i0) = (%arg0) to (%ub) step (%arg1) {
+  }
+  return
+}
+
+// -----
+
+func @parallel_arguments_wrong_step_type(%arg0: index, %arg1: index) {
+  %step = constant 1.0 : f32 // expected-note {{prior use here}}
+  // expected-error@+1 {{use of value '%step' expects different type than prior uses: 'index' vs 'f32'}}
+  loop.parallel (%i0) = (%arg0) to (%arg1) step (%step) {
+  }
+  return
+}
+
+// -----
+
+func @parallel_body_arguments_wrong_type(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{expected arguments for the induction variable to be of index type}}
+  "loop.parallel"(%arg0, %arg1, %arg2) ({
+    ^bb0(%i0: f32):
+      "loop.terminator"() : () -> ()
+  }): (index, index, index) -> ()
+  return
+}
+
+// -----
+
+func @parallel_body_wrong_number_of_arguments(
+    %arg0: index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{expected the same number of induction variables as bound and step values}}
+  "loop.parallel"(%arg0, %arg1, %arg2) ({
+    ^bb0(%i0: index, %i1: index):
+      "loop.terminator"() : () -> ()
+  }): (index, index, index) -> ()
+  return
+}
+
+// -----
+
+func @parallel_no_tuple_elements() {
+  // expected-error@+1 {{need at least one tuple element for lowerBound, upperBound and step}}
+  loop.parallel () = () to () step () {
+  }
+  return
+}
+
+// -----
+
+func @parallel_step_not_positive(
+    %arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+  // expected-error@+3 {{constant step operand must be positive}}
+  %c0 = constant 1 : index
+  %c1 = constant 0 : index
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%c0, %c1) {
+  }
+  return
+}
+
+// -----
+
+func @parallel_fewer_results_than_yields(
+    %arg0 : index, %arg1: index, %arg2: index) {
+  // expected-error@+1 {{expected number of results to be the same as number of yields}}
+  loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+    %c0 = constant 1.0 : f32
+    loop.yield(%c0) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        loop.yield.return %lhs : f32
+    } : f32
+  }
+  return
+}
+
+// -----
+
+func @parallel_more_results_than_yields(
+    %arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected number of results to be the same as number of yields}}
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+  } : f32
+
+  return
+}
+
+// -----
+
+func @parallel_different_types_of_results_and_yields(
+    %arg0 : index, %arg1: index, %arg2: index) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
+    // expected-error@+1 {{expected type of yield to be the same as result type: 'f32'}}
+    loop.yield(%arg0) {
+      ^bb0(%lhs: index, %rhs: index):
+        loop.yield.return %lhs : index
+    } : index
+  } : f32
+  return
+}
+
+// -----
+
+func @top_level_yield(%arg0 : f32) {
+  // expected-error@+1 {{expects parent op 'loop.parallel'}}
+  loop.yield(%arg0) {
+    ^bb0(%lhs : f32, %rhs : f32):
+      loop.yield.return %lhs : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yield_no_block(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{'loop.yield' op region #0 ('operation') failed to verify constraint: region with 1 blocks}}
+    loop.yield(%arg1) {
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yield_empty_block(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{the block inside yield should not be empty}}
+    loop.yield(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yield_too_many_args(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{expected two arguments to yield block of type 'f32'}}
+    loop.yield(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32, %other : f32):
+        loop.yield.return %lhs : f32
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yield_wrong_args(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{expected two arguments to yield block of type 'f32'}}
+    loop.yield(%arg1) {
+      ^bb0(%lhs : f32, %rhs : i32):
+        loop.yield.return %lhs : f32
+    } : f32
+  } : f32
+  return
+}
+
+
+// -----
+
+func @yield_wrong_terminator(%arg0 : index, %arg1 : f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    // expected-error@+1 {{the block inside yield should be terminated with a 'loop.yield.return' op}}
+    loop.yield(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+        "loop.terminator"(): () -> ()
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yieldReturn_wrong_type(%arg0 : index, %arg1: f32) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) {
+    loop.yield(%arg1) {
+      ^bb0(%lhs : f32, %rhs : f32):
+        %c0 = constant 1 : index
+        // expected-error@+1 {{YieldReturnOp needs to have type 'f32' (the type of the enclosing YieldOp)}}
+        loop.yield.return %c0 : index
+    } : f32
+  } : f32
+  return
+}
+
+// -----
+
+func @yieldReturn_not_inside_yield(%arg0 : f32) {
+  "foo.region"() ({
+    // expected-error@+1 {{op expects parent op 'loop.yield'}}
+    loop.yield.return %arg0 : f32
+  }): () -> ()
+  return
+}
diff --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -49,3 +49,47 @@
 // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 // CHECK-NEXT: } else {
 // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+
+func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                        %arg3 : index, %arg4 : index) {
+  %step = constant 1 : index
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                                          step (%arg4, %step) {
+    %min_cmp = cmpi "slt", %i0, %i1 : index
+    %min = select %min_cmp, %i0, %i1 : index
+    %max_cmp = cmpi "sge", %i0, %i1 : index
+    %max = select %max_cmp, %i0, %i1 : index
+    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) {
+      %cst = constant 1.0 : f32
+      loop.yield(%cst) {
+        ^bb0(%lhs : f32, %rhs: f32):
+          %res = addf %lhs, %rhs : f32
+          loop.yield.return %res : f32
+      } : f32
+    } : f32
+  }
+  return
+}
+// CHECK-LABEL: func @std_parallel_loop(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]:
+// CHECK: %[[STEP:.*]] = constant 1 : index
+// CHECK-NEXT: loop.parallel (%[[I0:.*]], %[[I1:.*]]) = (%[[ARG0]], %[[ARG1]]) to
+// CHECK: (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[STEP]]) {
+// CHECK-NEXT: %[[MIN_CMP:.*]] = cmpi "slt", %[[I0]], %[[I1]] : index
+// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_CMP]], %[[I0]], %[[I1]] : index
+// CHECK-NEXT: %[[MAX_CMP:.*]] = cmpi "sge", %[[I0]], %[[I1]] : index
+// CHECK-NEXT: %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
+// CHECK-NEXT: loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]]) step (%[[I1]]) {
+// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32
+// CHECK-NEXT: loop.yield(%[[CST]]) {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK-NEXT: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: loop.yield.return %[[RES]] : f32
+// CHECK-NEXT: } : f32
+// CHECK-NEXT: "loop.terminator"() : () -> ()
+// CHECK-NEXT: } : f32
+// CHECK-NEXT: "loop.terminator"() : () -> ()