diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1239,10 +1239,40 @@
         newSteps.push_back(step);
       }
     }
-    // Exit if all or none of the loop dimensions perform a single iteration.
-    if (newLowerBounds.size() == 0 ||
-        newLowerBounds.size() == op.lowerBound().size())
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newLowerBounds.size() == op.lowerBound().size())
       return failure();
+
+    if (newLowerBounds.empty()) {
+      // All of the loop dimensions perform a single iteration, so the body
+      // runs exactly once: inline it in place of the loop. Induction variables
+      // are resolved through `mapping` (to the loop lower bounds).
+      SmallVector<Value> results;
+      results.reserve(op.initVals().size());
+      for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
+        auto reduce = dyn_cast<ReduceOp>(bodyOp);
+        if (!reduce) {
+          rewriter.clone(bodyOp, mapping);
+          continue;
+        }
+        // With a single iteration the reduction result is the combiner
+        // applied once to (init value, reduced operand); the N-th scf.reduce
+        // in the body corresponds to the N-th init value.
+        Block &reduceBlock = reduce.reductionOperator().front();
+        auto initValIndex = results.size();
+        mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
+        mapping.map(reduceBlock.getArgument(1),
+                    mapping.lookupOrDefault(reduce.operand()));
+        for (auto &reduceBodyOp : reduceBlock.without_terminator())
+          rewriter.clone(reduceBodyOp, mapping);
+
+        auto result = mapping.lookupOrDefault(
+            cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
+        results.push_back(result);
+      }
+      rewriter.replaceOp(op, results);
+      return success();
+    }
     // Replace the parallel loop by lower-dimensional parallel loop.
auto newOp = rewriter.create(op.getLoc(), newLowerBounds, newUpperBounds, diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -3,7 +3,7 @@ // ----- -func @single_iteration(%A: memref) { +func @single_iteration_some(%A: memref) { %c0 = constant 0 : index %c1 = constant 1 : index %c2 = constant 2 : index @@ -19,7 +19,7 @@ return } -// CHECK-LABEL: func @single_iteration( +// CHECK-LABEL: func @single_iteration_some( // CHECK-SAME: [[ARG0:%.*]]: memref) { // CHECK-DAG: [[C42:%.*]] = constant 42 : i32 // CHECK-DAG: [[C7:%.*]] = constant 7 : index @@ -35,6 +35,70 @@ // ----- +func @single_iteration_all(%A: memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %c7 = constant 7 : index + %c10 = constant 10 : index + scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) { + %c42 = constant 42 : i32 + memref.store %c42, %A[%i0, %i1, %i2] : memref + scf.yield + } + return +} + +// CHECK-LABEL: func @single_iteration_all( +// CHECK-SAME: [[ARG0:%.*]]: memref) { +// CHECK-DAG: [[C42:%.*]] = constant 42 : i32 +// CHECK-DAG: [[C7:%.*]] = constant 7 : index +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-NOT: scf.parallel +// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref +// CHECK-NOT: scf.yield +// CHECK: return + +// ----- + +func @single_iteration_reduce(%A: index, %B: index) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c6 = constant 6 : index + %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) { + scf.reduce(%i0) : index { + ^bb0(%lhs: index, %rhs: index): + %1 = addi %lhs, %rhs : index + scf.reduce.return %1 : index + 
} + scf.reduce(%i1) : index { + ^bb0(%lhs: index, %rhs: index): + %2 = muli %lhs, %rhs : index + scf.reduce.return %2 : index + } + scf.yield + } + return %0#0, %0#1 : index, index +} + +// CHECK-LABEL: func @single_iteration_reduce( +// CHECK-SAME: [[ARG0:%.*]]: index, [[ARG1:%.*]]: index) +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-NOT: scf.parallel +// CHECK-NOT: scf.reduce +// CHECK-NOT: scf.reduce.return +// CHECK-NOT: scf.yield +// CHECK: [[V0:%.*]] = addi [[ARG0]], [[C1]] +// CHECK: [[V1:%.*]] = muli [[ARG1]], [[C3]] +// CHECK: return [[V0]], [[V1]] + +// ----- + func private @side_effect() func @one_unused(%cond: i1) -> (index) { %c0 = constant 0 : index @@ -488,7 +552,7 @@ %ub : index, %lb : index, %step : index) -> (i32, i32) { // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32 %cst = constant 32 : i32 - // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { + // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst) -> (i32, i32) { %1 = addi %arg2, %cst : i32 @@ -512,7 +576,7 @@ %1 = addi %arg2, %cst : i32 scf.yield %1, %1 : i32, i32 } - + // CHECK: return %[[FOR_RES]] : i32 return %0#0 : i32 }