Patch notes (reviewer-added; text ahead of the first "diff --git" line is
leading commentary, which GNU patch and git-apply are expected to skip).

What the patch does, as read from its hunks:

  * mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    - Removes the file-local helpers replaceInOp, updateInductionVariableUses,
      updateIterArgUses and updateCrossStageUses.
    - LoopPipelinerInternal::createKernel now collects every OpOperand of the
      original op and of its nested ops, finds the cloned owner through
      `mapping`, and rewrites each operand individually:
        1. a use of the old loop's induction variable becomes
           newIV + (maxStage - stages[op]) * step;
        2. a use of an old loop-body block argument whose yielded value is
           defined in a different stage (asserted to be useStage + 1) becomes
           the remapped yielded value;
        3. a use of a value defined by an op of a different stage becomes the
           new loop's region iter arg selected through `loopArgMap`.
  * mlir/test/Dialect/SCF/loop-pipelining.mlir
    - Adds @backedge_mix_order, a two-stage loop whose stage-0 and stage-1 ops
      are interleaved in the requested op order.

NOTE(review): this copy was rebuilt from a whitespace-collapsed version in
which everything between angle brackets had been stripped. The restored
tokens are: arith::ConstantIndexOp, arith::AddIOp, BlockArgument,
SmallVector<OpOperand *>, DenseMap<Operation *, unsigned>,
DenseMap<std::pair<Value, unsigned>, unsigned>,
MapVector<Value, LoopPipelinerInternal::LiverangeInfo> (context line, least
certain) and memref<?xf32>. Indentation and the alignment of the CHECK
prefixes are also reconstructed. Confirm against the upstream tree before
applying; `git apply --ignore-whitespace` may be needed.

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -294,81 +294,6 @@
   return newForOp;
 }
 
-/// Replace any use of `target` with `replacement` in `op`'s operands or within
-/// `op`'s nested regions.
-static void replaceInOp(Operation *op, Value target, Value replacement) {
-  for (auto &use : llvm::make_early_inc_range(target.getUses())) {
-    if (op->isAncestor(use.getOwner()))
-      use.set(replacement);
-  }
-}
-
-/// Given a cloned op in the new kernel body, updates induction variable uses.
-/// We replace it with a version incremented based on the stage where it is
-/// used.
-static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
-                                        Operation *newOp, Value newForIv,
-                                        unsigned maxStage, unsigned useStage,
-                                        unsigned step) {
-  rewriter.setInsertionPoint(newOp);
-  Value offset = rewriter.create<arith::ConstantIndexOp>(
-      loc, (maxStage - useStage) * step);
-  Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
-  replaceInOp(newOp, newForIv, iv);
-  rewriter.setInsertionPointAfter(newOp);
-}
-
-/// If the value is a loop carried value coming from stage N + 1 remap, it will
-/// become a direct use.
-static void updateIterArgUses(RewriterBase &rewriter, IRMapping &bvm,
-                              Operation *newOp, ForOp oldForOp, ForOp newForOp,
-                              unsigned useStage,
-                              const DenseMap<Operation *, unsigned> &stages) {
-
-  for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
-    Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
-    Operation *dep = yieldedVal.getDefiningOp();
-    if (!dep)
-      continue;
-    auto stageDep = stages.find(dep);
-    if (stageDep == stages.end() || stageDep->second == useStage)
-      continue;
-    if (stageDep->second != useStage + 1)
-      continue;
-    Value replacement = bvm.lookup(yieldedVal);
-    replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
-  }
-}
-
-/// For operands defined in a previous stage we need to remap it to use the
-/// correct region argument. We look for the right version of the Value based
-/// on the stage where it is used.
-static void updateCrossStageUses(
-    RewriterBase &rewriter, Operation *newOp, IRMapping &bvm, ForOp newForOp,
-    unsigned useStage, const DenseMap<Operation *, unsigned> &stages,
-    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
-  // Because we automatically cloned the sub-regions, there's no simple way
-  // to walk the nested regions in pairs of (oldOps, newOps), so we just
-  // traverse the set of remapped loop arguments, filter which ones are
-  // relevant, and replace any uses.
-  for (auto [remapPair, newIterIdx] : loopArgMap) {
-    auto [crossArgValue, stageIdx] = remapPair;
-    Operation *def = crossArgValue.getDefiningOp();
-    assert(def);
-    unsigned stageDef = stages.lookup(def);
-    if (useStage <= stageDef || useStage - stageDef != stageIdx)
-      continue;
-
-    // Use "lookupOrDefault" for the target value because some operations
-    // are remapped, while in other cases the original will be present.
-    Value target = bvm.lookupOrDefault(crossArgValue);
-    Value replacement = newForOp.getRegionIterArg(newIterIdx);
-
-    // Replace uses in the new op's operands and any nested uses.
-    replaceInOp(newOp, target, replacement);
-  }
-}
-
 void LoopPipelinerInternal::createKernel(
     scf::ForOp newForOp,
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -400,16 +325,59 @@
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
-
-    // Within the kernel body, update uses of the induction variable, uses of
-    // the original iter args, and uses of cross stage values.
-    updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
-                                newForOp.getInductionVar(), maxStage,
-                                stages[op], step);
-    updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
-                      stages);
-    updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
-                         loopArgMap);
+    SmallVector<OpOperand *> operands;
+    // Collect all the operands for the cloned op and its nested ops.
+    op->walk([&operands](Operation *nestedOp) {
+      for (OpOperand &operand : nestedOp->getOpOperands()) {
+        operands.push_back(&operand);
+      }
+    });
+    for (OpOperand *operand : operands) {
+      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
+      // Special case for the induction variable uses. We replace it with a
+      // version incremented based on the stage where it is used.
+      if (operand->get() == forOp.getInductionVar()) {
+        rewriter.setInsertionPoint(newOp);
+        Value offset = rewriter.create<arith::ConstantIndexOp>(
+            forOp.getLoc(), (maxStage - stages[op]) * step);
+        Value iv = rewriter.create<arith::AddIOp>(
+            forOp.getLoc(), newForOp.getInductionVar(), offset);
+        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
+        rewriter.setInsertionPointAfter(newOp);
+        continue;
+      }
+      auto arg = operand->get().dyn_cast<BlockArgument>();
+      if (arg && arg.getOwner() == forOp.getBody()) {
+        // If the value is a loop carried value coming from stage N + 1 remap,
+        // it will become a direct use.
+        Value ret = forOp.getBody()->getTerminator()->getOperand(
+            arg.getArgNumber() - 1);
+        Operation *dep = ret.getDefiningOp();
+        if (!dep)
+          continue;
+        auto stageDep = stages.find(dep);
+        if (stageDep == stages.end() || stageDep->second == useStage)
+          continue;
+        assert(stageDep->second == useStage + 1);
+        nestedNewOp->setOperand(operand->getOperandNumber(),
+                                mapping.lookupOrDefault(ret));
+        continue;
+      }
+      // For operands defined in a previous stage we need to remap it to use
+      // the correct region argument. We look for the right version of the
+      // Value based on the stage where it is used.
+      Operation *def = operand->get().getDefiningOp();
+      if (!def)
+        continue;
+      auto stageDef = stages.find(def);
+      if (stageDef == stages.end() || stageDef->second == useStage)
+        continue;
+      auto remap = loopArgMap.find(
+          std::make_pair(operand->get(), useStage - stageDef->second));
+      assert(remap != loopArgMap.end());
+      nestedNewOp->setOperand(operand->getOperandNumber(),
+                              newForOp.getRegionIterArgs()[remap->second]);
+    }
 
     if (predicates[useStage]) {
       newOp = predicateFn(newOp, predicates[useStage], rewriter);
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -627,3 +627,47 @@
   } { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+// CHECK-LABEL: @backedge_mix_order
+//  CHECK-SAME:   (%[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[ARG1:.*]] = %[[L0]], %[[ARG2:.*]] = %[[L1]]) -> (f32, f32, f32) {
+//  CHECK-NEXT:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[C]], %[[ARG1]] : f32
+//  CHECK-NEXT:     %[[IV3:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[IV4:.*]] = arith.addi %[[IV3]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L3:.*]] = memref.load %[[A]][%[[IV4]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[MUL1:.*]] = arith.mulf %[[ARG2]], %[[MUL0]] : f32
+//  CHECK-NEXT:     scf.yield %[[MUL1]], %[[L2]], %[[L3]] : f32, f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[R]]#0, %[[R]]#1 : f32
+//  CHECK-NEXT:   %[[MUL2:.*]] = arith.mulf %[[R]]#2, %[[MUL1]] : f32
+//  CHECK-NEXT:   return %[[MUL2]] : f32
+func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 2.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+    %A2_elem = arith.mulf %arg0, %A_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    %i1 = arith.addi %i0, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : index
+    %A1_elem = memref.load %A[%i1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A3_elem = arith.mulf %A1_elem, %A2_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 4 } : f32
+    scf.yield %A3_elem : f32
+  } { __test_pipelining_loop__ }
+  return %r : f32
+}
\ No newline at end of file