diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -74,7 +74,7 @@ PatternRewriter &rewriter); /// Emits the epilogue, this creates `maxStage - 1` part which will contain /// operations from stages [i; maxStage], where i is the part index. - void emitEpilogue(PatternRewriter &rewriter); + llvm::SmallVector emitEpilogue(PatternRewriter &rewriter); }; bool LoopPipelinerInternal::initializeLoopInfo( @@ -114,14 +114,25 @@ .wasInterrupted()) return false; - // TODO: Add support for loop with operands. - if (forOp.getNumIterOperands() > 0) + // Only support loop carried dependency with a distance of 1. This means the + // source of all the scf.yield operands needs to be defined by operations in + // the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || stages.find(def) == stages.end(); + })) return false; - return true; } void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) { + // Initialize the iteration arguments to the loop's initial values. + for (BlockArgument &arg : forOp.getRegionIterArgs()) { + OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); for (int64_t i = 0; i < maxStage; i++) { // special handling for induction variable as the increment is implicit. Value iv = rewriter.create(forOp.getLoc(), lb + i); @@ -138,6 +149,14 @@ for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(destId), newOp->getResult(destId), i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. 
+ for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } } } } @@ -173,7 +192,19 @@ // stages. The initial values come from the prologue created above. // Keep track of the kernel argument associated to each version of the // values passed to the kernel. - auto newLoopArg = llvm::to_vector<8>(forOp.getIterOperands()); + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (auto retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } for (auto escape : crossStageValues) { LiverangeInfo &info = escape.second; Value value = escape.first; @@ -210,6 +241,9 @@ rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); BlockAndValueMapping mapping; mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } for (Operation *op : opOrder) { int64_t useStage = stages[op]; auto *newOp = rewriter.clone(*op, mapping); @@ -226,6 +260,23 @@ rewriter.setInsertionPointAfter(newOp); continue; } + auto arg = operand.get().dyn_cast(); + if (arg && arg.getOwner() == forOp.getBody()) { + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. 
+ Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + assert(stageDep->second == useStage + 1); + newOp->setOperand(operand.getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } // For operands defined in a previous stage we need to remap it to use // the correct region argument. We look for the right version of the // Value based on the stage where it is used. @@ -249,6 +300,9 @@ // We create a mapping between original values and the associated loop // returned values that will be needed by the epilogue. llvm::SmallVector yieldOperands; + for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) { + yieldOperands.push_back(mapping.lookupOrDefault(retVal)); + } for (auto &it : crossStageValues) { int64_t version = maxStage - it.second.lastUseStage + 1; unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; @@ -266,10 +320,22 @@ version++); yieldOperands.push_back(mapping.lookupOrDefault(it.first)); } + // Map the yield operand to the forOp returned value. + for (auto retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage + 1); + } rewriter.create(forOp.getLoc(), yieldOperands); } -void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) { +llvm::SmallVector +LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) { + llvm::SmallVector returnValues(forOp->getNumResults()); // Emit different versions of the induction variable. They will be // removed by dead code if not used. 
for (int64_t i = 0; i < maxStage; i++) { @@ -295,9 +361,27 @@ for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(destId), newOp->getResult(destId), maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } } } } + return returnValues; } void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { @@ -361,12 +445,11 @@ // 4. Emit the epilogue after the new forOp. rewriter.setInsertionPointAfter(newForOp); - pipeliner.emitEpilogue(rewriter); + llvm::SmallVector returnValues = pipeliner.emitEpilogue(rewriter); // 5. Erase the original loop and replace the uses with the epilogue output. 
if (forOp->getNumResults() > 0) - rewriter.replaceOp( - forOp, newForOp.getResults().take_front(forOp->getNumResults())); + rewriter.replaceOp(forOp, returnValues); else rewriter.eraseOp(forOp); diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -171,3 +171,118 @@ } { __test_pipelining_loop__ } return } + +// ----- + +// CHECK-LABEL: loop_carried( +// CHECK-SAME: %[[A:.*]]: memref, %[[R:.*]]: memref) { +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-DAG: %[[C3:.*]] = constant 3 : index +// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32 +// Prologue: +// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref +// Kernel: +// CHECK-NEXT: %[[LR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]] +// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], +// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) { +// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32 +// CHECK-NEXT: %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index +// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[IV1]]] : memref +// CHECK-NEXT: scf.yield %[[ADD0]], %[[L1]] : f32, f32 +// CHECK-NEXT: } +// Epilogue: +// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LR]]#1, %[[LR]]#0 : f32 +// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C0]]] : memref +func @loop_carried(%A: memref, %result: memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %cf = constant 1.0 : f32 + %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { + %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref + %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32 + scf.yield %A1_elem : f32 + } { __test_pipelining_loop__ } + memref.store %r, %result[%c0] : memref + return +} + 
+// ----- + +// CHECK-LABEL: backedge_different_stage +// CHECK-SAME: (%[[A:.*]]: memref) -> f32 { +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-DAG: %[[C2:.*]] = constant 2 : index +// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32 +// Prologue: +// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref +// CHECK-NEXT: %[[ADD0:.*]] = addf %[[L0]], %[[CSTF]] : f32 +// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref +// Kernel: +// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]] +// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], +// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) { +// CHECK-NEXT: %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADDARG]] : f32 +// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LARG]], %[[MUL0]] : f32 +// CHECK-NEXT: %[[IV2:.*]] = addi %[[IV]], %[[C2]] : index +// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref +// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32 +// CHECK-NEXT: } +// Epilogue: +// CHECK-NEXT: %[[MUL1:.*]] = mulf %[[CSTF]], %[[R]]#1 : f32 +// CHECK-NEXT: %[[ADD2:.*]] = addf %[[R]]#2, %[[MUL1]] : f32 +// CHECK-NEXT: %[[MUL2:.*]] = mulf %[[CSTF]], %[[ADD2]] : f32 +// CHECK-NEXT: return %[[MUL2]] : f32 +func @backedge_different_stage(%A: memref) -> f32 { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %cf = constant 1.0 : f32 + %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { + %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref + %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32 + %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32 + scf.yield %A2_elem : f32 + } { __test_pipelining_loop__ } + return %r : f32 +} + +// ----- + +// CHECK-LABEL: 
backedge_same_stage +// CHECK-SAME: (%[[A:.*]]: memref) -> f32 { +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-DAG: %[[C3:.*]] = constant 3 : index +// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32 +// Prologue: +// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref +// Kernel: +// CHECK-NEXT: %[[R:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]] +// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], +// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) { +// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32 +// CHECK-NEXT: %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADD0]] : f32 +// CHECK-NEXT: %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index +// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref +// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32 +// CHECK-NEXT: } +// Epilogue: +// CHECK-NEXT: %[[ADD1:.*]] = addf %[[R]]#1, %[[R]]#0 : f32 +// CHECK-NEXT: %[[MUL1:.*]] = mulf %[[CSTF]], %[[ADD1]] : f32 +// CHECK-NEXT: return %[[MUL1]] : f32 +func @backedge_same_stage(%A: memref) -> f32 { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c4 = constant 4 : index + %cf = constant 1.0 : f32 + %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { + %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref + %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32 + %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32 + scf.yield %A2_elem : f32 + } { __test_pipelining_loop__ } + return %r : f32 +}