[mlir][loop] Print result types of loop.parallel / loop.reduce before the region

loop.parallel now spells its optional result types as an arrow list after the
`init` clause (`... init (%a) -> f32 { ... }`) instead of a trailing `: f32`
after the region, and loop.reduce puts its `: type` between the operand and
the region (`loop.reduce(%v) : f32 { ... }`). The custom parser/printer, the
LoopOps.td documentation example and the affected lit tests are updated.

diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -272,13 +272,13 @@
     For example:
 
     ```mlir
-    loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+    loop.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 {
       %zero = constant 0.0 : f32
-      loop.reduce(%zero) {
+      loop.reduce(%zero) : f32 {
         ^bb0(%lhs : f32, %rhs: f32):
           %res = addf %lhs, %rhs : f32
           loop.reduce.return %res : f32
-      } : f32
+      }
     }
     ```
   }];
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -407,7 +407,7 @@
       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
     return failure();
 
-  // Parse step value.
+  // Parse step values.
   SmallVector<OpAsmParser::OperandType, 4> steps;
   if (parser.parseKeyword("step") ||
       parser.parseOperandList(steps, ivs.size(),
@@ -415,7 +415,7 @@
       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
     return failure();
 
-  // Parse step value.
+  // Parse init values.
   SmallVector<OpAsmParser::OperandType, 4> initVals;
   if (succeeded(parser.parseOptionalKeyword("init"))) {
     if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
@@ -423,6 +423,10 @@
       return failure();
   }
 
+  // Parse optional results in case there is a reduce.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
   // Now parse the body.
   Region *body = result.addRegion();
   SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
@@ -437,9 +441,8 @@
                                 static_cast<int32_t>(steps.size()),
                                 static_cast<int32_t>(initVals.size())}));
 
-  // Parse attributes and optional results (in case there is a reduce).
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseOptionalColonTypeList(result.types))
+  // Parse attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   if (!initVals.empty())
@@ -457,11 +460,10 @@
     << ")";
   if (!op.initVals().empty())
     p << " init (" << op.initVals() << ")";
+  p.printOptionalArrowTypeList(op.getResultTypes());
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(
       op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
-  if (!op.results().empty())
-    p << " : " << op.getResultTypes();
 }
 
 ParallelOp mlir::loop::getParallelForInductionVarOwner(Value val) {
@@ -515,24 +517,24 @@
       parser.parseRParen())
     return failure();
 
-  // Now parse the body.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
-    return failure();
-
-  // And the type of the operand (and also what reduce computes on).
+  // Parse the type of the operand (and also what reduce computes on).
   Type resultType;
   if (parser.parseColonType(resultType) ||
       parser.resolveOperand(operand, resultType, result.operands))
     return failure();
 
+  // Now parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
   return success();
 }
 
 static void print(OpAsmPrinter &p, ReduceOp op) {
-  p << op.getOperationName() << "(" << op.operand() << ") ";
-  p.printRegion(op.reductionOperator());
-  p << " : " << op.operand().getType();
+  p << op.getOperationName() << "(" << op.operand() << ") : "
+    << op.operand().getType();
+  p.printRegion(op.reductionOperator());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -268,14 +268,14 @@
   // The continuation block has access to the (last value of) reduction.
   // CHECK: ^[[CONTINUE]]:
   // CHECK:   return %[[ITER_ARG]]
-  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) -> f32 {
     %cst = constant 42.0 : f32
-    loop.reduce(%cst) {
+    loop.reduce(%cst) : f32 {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = mulf %lhs, %rhs : f32
       loop.reduce.return %1 : f32
-    } : f32
-  } : f32
+    }
+  }
   return %0 : f32
 }
 
@@ -304,20 +304,20 @@
   %step = constant 1 : index
   %init = constant 42 : i64
   %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-                       step (%arg4, %step) init(%arg5, %init) {
+                       step (%arg4, %step) init(%arg5, %init) -> (f32, i64) {
     %cf = constant 42.0 : f32
-    loop.reduce(%cf) {
+    loop.reduce(%cf) : f32 {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = addf %lhs, %rhs : f32
       loop.reduce.return %1 : f32
-    } : f32
+    }
     %2 = call @generate() : () -> i64
-    loop.reduce(%2) {
+    loop.reduce(%2) : i64 {
     ^bb0(%lhs: i64, %rhs: i64):
       %3 = or %lhs, %rhs : i64
       loop.reduce.return %3 : i64
-    } : i64
-  } : f32, i64
+    }
+  }
   return %0#0, %0#1 : f32, i64
 }
 
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -175,10 +175,10 @@
   // expected-error@+1 {{expects number of results: 0 to be the same as number of reductions: 1}}
   loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
     %c0 = constant 1.0 : f32
-    loop.reduce(%c0) {
+    loop.reduce(%c0) : f32 {
       ^bb0(%lhs: f32, %rhs: f32):
         loop.reduce.return %lhs : f32
-    } : f32
+    }
   }
   return
 }
@@ -189,8 +189,8 @@
     %arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-error@+2 {{expects number of results: 1 to be the same as number of reductions: 0}}
   %zero = constant 1.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
-  } : f32
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 {
+  }
   return
 }
 
@@ -200,12 +200,12 @@
 func @parallel_more_results_than_initial_values(
     %arg0 : index, %arg1: index, %arg2: index) {
   // expected-error@+1 {{expects number of results: 1 to be the same as number of initial values: 0}}
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
-    loop.reduce(%arg0) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
+    }
+  }
   return
 }
 
@@ -214,13 +214,14 @@
 func @parallel_different_types_of_results_and_reduces(
     %arg0 : index, %arg1: index, %arg2: index) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1)
+                                       step (%arg2) init (%zero) -> f32 {
     // expected-error@+1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}}
-    loop.reduce(%arg0) {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
+    }
+  }
   return
 }
 
@@ -228,10 +229,10 @@
 
 func @top_level_reduce(%arg0 : f32) {
   // expected-error@+1 {{expects parent op 'loop.parallel'}}
-  loop.reduce(%arg0) {
+  loop.reduce(%arg0) : f32 {
     ^bb0(%lhs : f32, %rhs : f32):
       loop.reduce.return %lhs : f32
-  } : f32
+  }
   return
 }
 
@@ -239,12 +240,13 @@
 
 func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{the block inside reduce should not be empty}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -252,13 +254,14 @@
 
 func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32, %other : f32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -266,13 +269,14 @@
 
 func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : i32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -281,13 +285,14 @@
 
 func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error@+1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         loop.yield
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -295,14 +300,15 @@
 
 func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
-    loop.reduce(%arg1) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         %c0 = constant 1 : index
         // expected-error@+1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
         loop.reduce.return %c0 : index
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -349,7 +355,8 @@
   %s0 = constant 0.0 : f32
   %t0 = constant 1 : i32
   // expected-error@+1 {{mismatch in number of loop-carried values and defined values}}
-  %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
+  %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
     %sn = addf %si, %si : f32
     %tn = addi %ti, %ti : i32
     loop.yield %sn, %tn, %sn : f32, i32, f32
@@ -364,7 +371,8 @@
   %t0 = constant 1 : i32
   %u0 = constant 1.0 : f32
   // expected-error@+1 {{mismatch in number of loop-carried values and defined values}}
-  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
+  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
     %sn = addf %si, %si : f32
     %tn = addi %ti, %ti : i32
     %un = subf %ui, %ui : f32
@@ -379,8 +387,9 @@
   // expected-note@+1 {{prior use here}}
   %s0 = constant 0.0 : f32
   %t0 = constant 1.0 : f32
-  // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
-  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
+  // expected-error@+2 {{expects different type than prior uses: 'i32' vs 'f32'}}
+  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
     %sn = addf %si, %si : i32
     %tn = addf %ti, %ti : i32
     loop.yield %sn, %tn : i32, i32
diff --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -60,14 +60,22 @@
     %max_cmp = cmpi "sge", %i0, %i1 : index
     %max = select %max_cmp, %i0, %i1 : index
     %zero = constant 0.0 : f32
-    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) init (%zero) {
+    %int_zero = constant 0 : i32
+    %red:2 = loop.parallel (%i2) = (%min) to (%max) step (%i1)
+                                          init (%zero, %int_zero) -> (f32, i32) {
       %one = constant 1.0 : f32
-      loop.reduce(%one) {
+      loop.reduce(%one) : f32 {
         ^bb0(%lhs : f32, %rhs: f32):
           %res = addf %lhs, %rhs : f32
           loop.reduce.return %res : f32
-      } : f32
-    } : f32
+      }
+      %int_one = constant 1 : i32
+      loop.reduce(%int_one) : i32 {
+        ^bb0(%lhs : i32, %rhs: i32):
+          %res = muli %lhs, %rhs : i32
+          loop.reduce.return %res : i32
+      }
+    }
   }
   return
 }
@@ -85,16 +93,24 @@
 // CHECK-NEXT:   %[[MAX_CMP:.*]] = cmpi "sge", %[[I0]], %[[I1]] : index
 // CHECK-NEXT:   %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
 // CHECK-NEXT:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[INT_ZERO:.*]] = constant 0 : i32
 // CHECK-NEXT:   loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]])
-// CHECK-SAME:          step (%[[I1]]) init (%[[ZERO]]) {
+// CHECK-SAME:          step (%[[I1]])
+// CHECK-SAME:          init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) {
 // CHECK-NEXT:   %[[ONE:.*]] = constant 1.000000e+00 : f32
-// CHECK-NEXT:   loop.reduce(%[[ONE]]) {
+// CHECK-NEXT:   loop.reduce(%[[ONE]]) : f32 {
 // CHECK-NEXT:   ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 // CHECK-NEXT:     %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 // CHECK-NEXT:     loop.reduce.return %[[RES]] : f32
-// CHECK-NEXT:   } : f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[INT_ONE:.*]] = constant 1 : i32
+// CHECK-NEXT:   loop.reduce(%[[INT_ONE]]) : i32 {
+// CHECK-NEXT:   ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT:     %[[RES:.*]] = muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:     loop.reduce.return %[[RES]] : i32
+// CHECK-NEXT:   }
 // CHECK-NEXT:   loop.yield
-// CHECK-NEXT: } : f32
+// CHECK-NEXT: }
 // CHECK-NEXT: loop.yield
 
 func @parallel_explicit_yield(