diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -272,13 +272,15 @@
     For example:
 
     ```mlir
-      loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+      %init = constant 0.0 : f32
+      %result = loop.parallel (%iv) = (%lb) to (%ub) step (%step)
+          init (%init) -> f32 {
         %zero = constant 0.0 : f32
-        loop.reduce(%zero) {
+        loop.reduce(%zero) : f32 {
           ^bb0(%lhs : f32, %rhs: f32):
             %res = addf %lhs, %rhs : f32
             loop.reduce.return %res : f32
-        } : f32
+        }
       }
     ```
   }];
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -423,6 +423,10 @@
       return failure();
   }
 
+  // Parse optional results in case there is a reduce.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
   // Now parse the body.
   Region *body = result.addRegion();
   SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
@@ -437,9 +441,8 @@
                                 static_cast<int32_t>(steps.size()),
                                 static_cast<int32_t>(initVals.size())}));
 
-  // Parse attributes and optional results (in case there is a reduce).
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseOptionalColonTypeList(result.types))
+  // Parse attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   if (!initVals.empty())
@@ -457,11 +460,10 @@
     << ")";
   if (!op.initVals().empty())
     p << " init (" << op.initVals() << ")";
+  p.printOptionalArrowTypeList(op.getResultTypes());
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(
       op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
-  if (!op.results().empty())
-    p << " : " << op.getResultTypes();
 }
 
 ParallelOp mlir::loop::getParallelForInductionVarOwner(Value val) {
@@ -515,24 +517,24 @@
       parser.parseRParen())
     return failure();
 
-  // Now parse the body.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
-    return failure();
-
-  // And the type of the operand (and also what reduce computes on).
+  // Parse the type of the operand (and also what reduce computes on).
   Type resultType;
   if (parser.parseColonType(resultType) ||
       parser.resolveOperand(operand, resultType, result.operands))
     return failure();
 
+  // Now parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
   return success();
 }
 
 static void print(OpAsmPrinter &p, ReduceOp op) {
-  p << op.getOperationName() << "(" << op.operand() << ") ";
-  p.printRegion(op.reductionOperator());
+  p << op.getOperationName() << "(" << op.operand() << ")";
   p << " : " << op.operand().getType();
+  p.printRegion(op.reductionOperator());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -175,10 +175,10 @@
   // expected-error@+1 {{expects number of results: 0 to be the same as number of reductions: 1}}
   loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
     %c0 = constant 1.0 : f32
-    loop.reduce(%c0) {
+    loop.reduce(%c0) : f32 {
       ^bb0(%lhs: f32, %rhs: f32):
         loop.reduce.return %lhs : f32
-    } : f32
+    }
   }
   return
 }
@@ -189,8 +189,8 @@
     %arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-error@+2 {{expects number of results: 1 to be the same as number of reductions: 0}}
   %zero = constant 1.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
-  } : f32
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 {
+  }
   return
 }
 
@@ -200,12 +200,12 @@
 func @parallel_more_results_than_initial_values(
     %arg0 : index, %arg1: index, %arg2: index) {
   // expected-error@+1 {{expects number of results: 1 to be the same as number of initial values: 0}}
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
-    loop.reduce(%arg0) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
+    }
+  }
   return
 }
 
@@ -214,13 +214,13 @@
 func @parallel_different_types_of_results_and_reduces(
     %arg0 : index, %arg1: index, %arg2: index) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 {
     // expected-error@+1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}}
-    loop.reduce(%arg0) {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
+    }
+  }
   return
 }
 
@@ -228,10 +228,10 @@
 
 func @top_level_reduce(%arg0 : f32) {
   // expected-error@+1 {{expects parent op 'loop.parallel'}}
-  loop.reduce(%arg0) {
+  loop.reduce(%arg0) : f32 {
     ^bb0(%lhs : f32, %rhs : f32):
       loop.reduce.return %lhs : f32
-  } : f32
+  }
   return
 }
 
@@ -239,12 +239,12 @@
 
 func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{the block inside reduce should not be empty}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -252,13 +252,13 @@
 
 func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32, %other : f32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -266,13 +266,13 @@
 
 func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : i32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -281,13 +281,13 @@
 
 func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         loop.yield
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -295,14 +295,14 @@
 
 func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
-    loop.reduce(%arg1) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         %c0 = constant 1 : index
         // expected-error@+1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
         loop.reduce.return %c0 : index
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
diff --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -60,14 +60,15 @@
     %max_cmp = cmpi "sge", %i0, %i1 : index
     %max = select %max_cmp, %i0, %i1 : index
     %zero = constant 0.0 : f32
-    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) init (%zero) {
+    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1)
+                                          init (%zero) -> f32 {
       %one = constant 1.0 : f32
-      loop.reduce(%one) {
+      loop.reduce(%one) : f32 {
         ^bb0(%lhs : f32, %rhs: f32):
           %res = addf %lhs, %rhs : f32
           loop.reduce.return %res : f32
-      } : f32
-    } : f32
+      }
+    }
   }
   return
 }
@@ -86,15 +87,15 @@
 // CHECK-NEXT:   %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
 // CHECK-NEXT:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
 // CHECK-NEXT:   loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]])
-// CHECK-SAME:          step (%[[I1]]) init (%[[ZERO]]) {
+// CHECK-SAME:          step (%[[I1]]) init (%[[ZERO]]) -> f32 {
 // CHECK-NEXT:     %[[ONE:.*]] = constant 1.000000e+00 : f32
-// CHECK-NEXT:     loop.reduce(%[[ONE]]) {
+// CHECK-NEXT:     loop.reduce(%[[ONE]]) : f32 {
 // CHECK-NEXT:     ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 // CHECK-NEXT:       %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 // CHECK-NEXT:       loop.reduce.return %[[RES]] : f32
-// CHECK-NEXT:     } : f32
+// CHECK-NEXT:     }
 // CHECK-NEXT:     loop.yield
-// CHECK-NEXT:   } : f32
+// CHECK-NEXT:   }
 // CHECK-NEXT:   loop.yield
 
 func @parallel_explicit_yield(