diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -233,6 +233,12 @@
     Block::BlockArgListType getRegionIterArgs() {
       return getBody()->getArguments().drop_front(getNumInductionVars());
     }
+    /// Return the `index`-th region iteration argument.
+    BlockArgument getRegionIterArg(unsigned index) {
+      assert(index < getNumRegionIterArgs() &&
+             "expected an index less than the number of region iter args");
+      return getBody()->getArguments().drop_front(getNumInductionVars())[index];
+    }
     Operation::operand_range getIterOperands() {
       return getOperands().drop_front(getNumControlOperands());
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 
 using namespace mlir;
@@ -114,15 +115,28 @@
     return false;
 
   // All operations need to have a stage.
-  if (forOp
-          .walk([this](Operation *op) {
-            if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
-                stages.find(op) == stages.end())
-              return WalkResult::interrupt();
-            return WalkResult::advance();
-          })
-          .wasInterrupted())
-    return false;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (stages.find(&op) == stages.end()) {
+      op.emitOpError("not assigned a pipeline stage");
+      return false;
+    }
+  }
+
+  // Currently, we do not support assigning stages to ops in nested regions. The
+  // block of all operations assigned a stage should be the single `scf.for`
+  // body block.
+  for (const auto &[op, stageNum] : stages) {
+    (void)stageNum;
+    if (op == forOp.getBody()->getTerminator()) {
+      op->emitError("terminator should not be assigned a stage");
+      return false;
+    }
+    if (op->getBlock() != forOp.getBody()) {
+      op->emitOpError("the owning Block of all operations assigned a stage "
+                      "should be the loop body block");
+      return false;
+    }
+  }
 
   // Only support loop carried dependency with a distance of 1. This means the
   // source of all the scf.yield operands needs to be defined by operations in
@@ -137,6 +151,27 @@
   return true;
 }
 
+/// Clone `op` and call `callback` on the cloned op's operands as well as any
+/// operands of nested ops that:
+/// 1) aren't defined within the new op or
+/// 2) are block arguments.
+static Operation *
+cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
+                       function_ref<void(OpOperand *newOperand)> callback) {
+  Operation *clone = rewriter.clone(*op);
+  for (OpOperand &operand : clone->getOpOperands())
+    callback(&operand);
+  clone->walk([&](Operation *nested) {
+    for (OpOperand &operand : nested->getOpOperands()) {
+      Operation *def = operand.get().getDefiningOp();
+      if ((def && !clone->isAncestor(def)) ||
+          operand.get().isa<BlockArgument>())
+        callback(&operand);
+    }
+  });
+  return clone;
+}
+
 void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
   // Initialize the iteration argument to the loop initiale values.
   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
@@ -152,12 +187,14 @@
     for (Operation *op : opOrder) {
       if (stages[op] > i)
         continue;
-      Operation *newOp = rewriter.clone(*op);
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end())
-          newOp->setOperand(opIdx, it->second[i - stages[op]]);
-      }
+      Operation *newOp =
+          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
+            auto it = valueMapping.find(newOperand->get());
+            if (it != valueMapping.end()) {
+              Value replacement = it->second[i - stages[op]];
+              newOperand->set(replacement);
+            }
+          });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -181,18 +218,25 @@
   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
   for (Operation *op : opOrder) {
     unsigned stage = stages[op];
-    for (OpOperand &operand : op->getOpOperands()) {
+
+    auto analyzeOperand = [&](OpOperand &operand) {
       Operation *def = operand.get().getDefiningOp();
       if (!def)
-        continue;
+        return;
       auto defStage = stages.find(def);
       if (defStage == stages.end() || defStage->second == stage)
-        continue;
+        return;
       assert(stage > defStage->second);
       LiverangeInfo &info = crossStageValues[operand.get()];
       info.defStage = defStage->second;
       info.lastUseStage = std::max(info.lastUseStage, stage);
-    }
+    };
+
+    for (OpOperand &operand : op->getOpOperands())
+      analyzeOperand(operand);
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+      analyzeOperand(*operand);
+    });
   }
   return crossStageValues;
 }
@@ -243,9 +287,89 @@
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
+  // When there are no iter args, the loop body terminator will be created.
+  // Since we always create it below, remove the terminator if it was created.
+  if (!newForOp.getBody()->empty())
+    rewriter.eraseOp(newForOp.getBody()->getTerminator());
   return newForOp;
 }
 
+/// Replace any use of `target` with `replacement` in `op`'s operands or within
+/// `op`'s nested regions.
+static void replaceInOp(Operation *op, Value target, Value replacement) {
+  for (auto &use : llvm::make_early_inc_range(target.getUses())) {
+    if (op->isAncestor(use.getOwner()))
+      use.set(replacement);
+  }
+}
+
+/// Given a cloned op in the new kernel body, update the uses of the induction
+/// variable. We replace it with a version incremented based on the stage where
+/// it is used.
+static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
+                                        Operation *newOp, Value newForIv,
+                                        unsigned maxStage, unsigned useStage,
+                                        unsigned step) {
+  rewriter.setInsertionPoint(newOp);
+  Value offset = rewriter.create<arith::ConstantIndexOp>(
+      loc, (maxStage - useStage) * step);
+  Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
+  replaceInOp(newOp, newForIv, iv);
+  rewriter.setInsertionPointAfter(newOp);
+}
+
+/// If the value is a loop-carried value coming from stage N + 1, remap it; it
+/// will become a direct use.
+static void updateIterArgUses(RewriterBase &rewriter, BlockAndValueMapping &bvm,
+                              Operation *newOp, ForOp oldForOp, ForOp newForOp,
+                              unsigned useStage,
+                              const DenseMap<Operation *, unsigned> &stages) {
+
+  for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
+    Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
+    Operation *dep = yieldedVal.getDefiningOp();
+    if (!dep)
+      continue;
+    auto stageDep = stages.find(dep);
+    if (stageDep == stages.end() || stageDep->second == useStage)
+      continue;
+    if (stageDep->second != useStage + 1)
+      continue;
+    Value replacement = bvm.lookup(yieldedVal);
+    replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
+  }
+}
+
+/// For operands defined in a previous stage we need to remap them to use the
+/// correct region argument. We look for the right version of the Value based
+/// on the stage where it is used.
+static void updateCrossStageUses(
+    RewriterBase &rewriter, Operation *newOp, BlockAndValueMapping &bvm,
+    ForOp newForOp, unsigned useStage,
+    const DenseMap<Operation *, unsigned> &stages,
+    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
+  // Because we automatically cloned the sub-regions, there's no simple way
+  // to walk the nested regions in pairs of (oldOps, newOps), so we just
+  // traverse the set of remapped loop arguments, filter which ones are
+  // relevant, and replace any uses.
+  for (auto [remapPair, newIterIdx] : loopArgMap) {
+    auto [crossArgValue, stageIdx] = remapPair;
+    Operation *def = crossArgValue.getDefiningOp();
+    assert(def);
+    unsigned stageDef = stages.lookup(def);
+    if (useStage <= stageDef || useStage - stageDef != stageIdx)
+      continue;
+
+    // Use "lookupOrDefault" for the target value because some operations
+    // are remapped, while in other cases the original will be present.
+    Value target = bvm.lookupOrDefault(crossArgValue);
+    Value replacement = newForOp.getRegionIterArg(newIterIdx);
+
+    // Replace uses in the new op's operands and any nested uses.
+    replaceInOp(newOp, target, replacement);
+  }
+}
+
 void LoopPipelinerInternal::createKernel(
     scf::ForOp newForOp,
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -277,51 +401,17 @@
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Special case for the induction variable uses. We replace it with a
-      // version incremented based on the stage where it is used.
-      if (operand.get() == forOp.getInductionVar()) {
-        rewriter.setInsertionPoint(newOp);
-        Value offset = rewriter.create<arith::ConstantIndexOp>(
-            forOp.getLoc(), (maxStage - stages[op]) * step);
-        Value iv = rewriter.create<arith::AddIOp>(
-            forOp.getLoc(), newForOp.getInductionVar(), offset);
-        newOp->setOperand(operand.getOperandNumber(), iv);
-        rewriter.setInsertionPointAfter(newOp);
-        continue;
-      }
-      auto arg = operand.get().dyn_cast<BlockArgument>();
-      if (arg && arg.getOwner() == forOp.getBody()) {
-        // If the value is a loop carried value coming from stage N + 1 remap,
-        // it will become a direct use.
-        Value ret = forOp.getBody()->getTerminator()->getOperand(
-            arg.getArgNumber() - 1);
-        Operation *dep = ret.getDefiningOp();
-        if (!dep)
-          continue;
-        auto stageDep = stages.find(dep);
-        if (stageDep == stages.end() || stageDep->second == useStage)
-          continue;
-        assert(stageDep->second == useStage + 1);
-        newOp->setOperand(operand.getOperandNumber(),
-                          mapping.lookupOrDefault(ret));
-        continue;
-      }
-      // For operands defined in a previous stage we need to remap it to use
-      // the correct region argument. We look for the right version of the
-      // Value based on the stage where it is used.
-      Operation *def = operand.get().getDefiningOp();
-      if (!def)
-        continue;
-      auto stageDef = stages.find(def);
-      if (stageDef == stages.end() || stageDef->second == useStage)
-        continue;
-      auto remap = loopArgMap.find(
-          std::make_pair(operand.get(), useStage - stageDef->second));
-      assert(remap != loopArgMap.end());
-      newOp->setOperand(operand.getOperandNumber(),
-                        newForOp.getRegionIterArgs()[remap->second]);
-    }
+
+    // Within the kernel body, update uses of the induction variable, uses of
+    // the original iter args, and uses of cross stage values.
+    updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
+                                newForOp.getInductionVar(), maxStage,
+                                stages[op], step);
+    updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
+                      stages);
+    updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
+                         loopArgMap);
+
     if (predicates[useStage]) {
       newOp = predicateFn(newOp, predicates[useStage], rewriter);
       // Remap the results to the new predicated one.
@@ -382,21 +472,20 @@
         forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
   }
-  // Emit `maxStage - 1` epilogue part that includes operations fro stages
+  // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
-      Operation *newOp = rewriter.clone(*op);
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end()) {
-          Value v = it->second[maxStage - stages[op] + i];
-          assert(v);
-          newOp->setOperand(opIdx, v);
-        }
-      }
+      Operation *newOp =
+          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
+            auto it = valueMapping.find(newOperand->get());
+            if (it != valueMapping.end()) {
+              Value replacement = it->second[maxStage - stages[op] + i];
+              newOperand->set(replacement);
+            }
+          });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -34,6 +34,54 @@
   return
 }
 
+
+// -----
+
+// CHECK-LABEL: simple_pipeline_region(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = scf.execute_region
+// CHECK-NEXT: memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+// CHECK: %[[L1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[LARG:.*]] = %[[L0]]) -> (f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[LARG]], %{{.*}} : f32
+// CHECK: memref.store %[[ADD0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[LR:.*]] = scf.execute_region
+// CHECK-NEXT: memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK: scf.yield %[[LR]] : f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[L1]], %{{.*}} : f32
+// CHECK: memref.store %[[ADD1]], %[[R]][%[[C3]]] : memref<?xf32>
+func.func @simple_pipeline_region(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c4 step %c1 {
+
+    %A_elem = scf.execute_region -> f32 {
+      %A_elem1 = memref.load %A[%i0] : memref<?xf32>
+      scf.yield %A_elem1 : f32
+    } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 }
+
+    %A1_elem = scf.execute_region -> f32 {
+      %A1_elem1 = arith.addf %A_elem, %cf : f32
+      scf.yield %A1_elem1 : f32
+    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 }
+
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
+
 // -----
 
 // CHECK-LABEL: simple_pipeline_step(
@@ -269,6 +317,65 @@
 
 // -----
 
+// CHECK-LABEL: region_multiple_uses(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[L0]], %{{.*}} : f32
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]], %{{.*}} : f32
+// CHECK-NEXT: %[[MUL0:.*]] = scf.execute_region
+//   arith.mulf %[[ADD0]], %[[L0]] : f32
+// CHECK: %[[L2:.*]] = memref.load %[[A]][%[[C2]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[LR:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] to %[[C7]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[LA1:.*]] = %[[L1]],
+// CHECK-SAME: %[[LA2:.*]] = %[[L2]], %[[ADDARG1:.*]] = %[[ADD1]],
+// CHECK-SAME: %[[MULARG0:.*]] = %[[MUL0]]) -> (f32, f32, f32, f32) {
+// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[LA2]], %{{.*}} : f32
+// CHECK-NEXT: %[[MUL1:.*]] = scf.execute_region
+//   arith.mulf %[[ADDARG1]], %[[LA1]] : f32
+// CHECK: memref.store %[[MULARG0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV3:.*]] = arith.addi %[[IV]], %[[C3]] : index
+// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[IV3]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[LA2]], %[[L3]], %[[ADD2]], %[[MUL1]] : f32, f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD3:.*]] = arith.addf %[[LR]]#1, %{{.*}} : f32
+// CHECK-NEXT: %[[MUL2:.*]] = scf.execute_region
+//   arith.mulf %[[LR]]#2, %[[LR]]#0 : f32
+// CHECK: memref.store %[[LR]]#3, %[[R]][%[[C7]]] : memref<?xf32>
+// CHECK-NEXT: %[[MUL3:.*]] = scf.execute_region
+//   arith.mulf %[[ADD3]], %[[LR]]#1 : f32
+// CHECK: memref.store %[[MUL2]], %[[R]][%[[C8]]] : memref<?xf32>
+// CHECK-NEXT: memref.store %[[MUL3]], %[[R]][%[[C9]]] : memref<?xf32>
+
+func.func @region_multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c10 step %c1 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = scf.execute_region -> f32 {
+      %A2_elem1 = arith.mulf %A1_elem, %A_elem : f32
+      scf.yield %A2_elem1 : f32
+    } { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 1 }
+    memref.store %A2_elem, %result[%i0] { __test_pipelining_stage__ = 3, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: loop_carried(
 // CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -341,6 +448,58 @@
   return %r : f32
 }
 
+// -----
+
+// CHECK-LABEL: region_backedge_different_stage
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = scf.execute_region
+// CHECK-NEXT: memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK: %[[ADD0:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32
+// CHECK: %[[L1:.*]] = scf.execute_region
+// CHECK-NEXT: memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+// CHECK: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+// CHECK: %[[ADD1:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[LARG]], %[[ADDARG]] : f32
+// CHECK: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
+// CHECK: %[[L2:.*]] = scf.execute_region
+// CHECK-NEXT: memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+// CHECK: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK: %[[ADD2:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[R]]#2, %[[R]]#1 : f32
+// CHECK: return %[[ADD2]] : f32
+
+func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = scf.execute_region -> f32 {
+      %A_elem1 = memref.load %A[%i0] : memref<?xf32>
+      scf.yield %A_elem1 : f32
+    } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 }
+    %A1_elem = scf.execute_region -> f32 {
+      %inner = arith.addf %A_elem, %arg0 : f32
+      scf.yield %inner : f32
+    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 }
+    %A2_elem = arith.mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    scf.yield %A2_elem : f32
+  } { __test_pipelining_loop__ }
+  return %r : f32
+}
+
+
 // -----
 
 // CHECK-LABEL: backedge_same_stage
@@ -376,3 +535,88 @@
   } { __test_pipelining_loop__ }
   return %r : f32
 }
+
+// -----
+
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK: %[[C0:.+]] = arith.constant 0 :
+// CHECK: %[[C3:.+]] = arith.constant 3 :
+// CHECK: %[[C1:.+]] = arith.constant 1 :
+// CHECK: %[[APRO:.+]] = memref.alloc() :
+// CHECK: %[[BPRO:.+]] = memref.alloc() :
+// CHECK: %[[ASV0:.+]] = memref.subview %[[ARG0]][%[[C0]]] [8] [1] :
+// CHECK: %[[BSV0:.+]] = memref.subview %[[ARG1]][%[[C0]]] [8] [1] :
+
+// Prologue:
+// CHECK: %[[PAV0:.+]] = memref.subview %[[APRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: %[[PBV0:.+]] = memref.subview %[[BPRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV0]], %[[PAV0]] :
+// CHECK: memref.copy %[[BSV0]], %[[PBV0]] :
+
+// Kernel:
+// CHECK: %[[R:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[IA:.+]] = %[[PAV0]], %[[IB:.+]] = %[[PBV0:.+]])
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[IA]], %[[IB]], %{{.*}} : {{.*}}) outs(%[[CV]] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]]
+// CHECK: %[[ASV:.+]] = memref.subview %[[ARG0]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BSV:.+]] = memref.subview %[[ARG1]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BUFIDX:.+]] = affine.apply
+// CHECK: %[[APROSV:.+]] = memref.subview %[[APRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: %[[BPROSV:.+]] = memref.subview %[[BPRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV]], %[[APROSV]] :
+// CHECK: memref.copy %[[BSV]], %[[BPROSV]] :
+// CHECK: scf.yield %[[APROSV]], %[[BPROSV]] :
+// CHECK: }
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]][%[[C3]]] [8] [1] :
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[R]]#0, %[[R]]#1, %{{.*}} : {{.*}}) outs(%[[CV]] :
+
+
+#map = affine_map<(d0)[s0]->(d0 + s0)>
+#map1 = affine_map<(d0)->(d0)>
+#map2 = affine_map<(d0)->()>
+#linalg_attrs = {
+  indexing_maps = [
+    #map1,
+    #map1,
+    #map2,
+    #map1
+  ],
+  iterator_types = ["parallel"],
+  __test_pipelining_stage__ = 1,
+  __test_pipelining_op_order__ = 2
+}
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %a_buf = memref.alloc() : memref<2x8xf32>
+  %b_buf = memref.alloc() : memref<2x8xf32>
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_view = memref.subview %A[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32> to memref<8xf32, #map>
+    %B_view = memref.subview %B[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 4 } : memref<?xf32> to memref<8xf32, #map>
+    %buf_idx = affine.apply affine_map<(d0)->(d0 mod 2)> (%i0)[] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 5 }
+    %a_buf_view = memref.subview %a_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 6 } : memref<2x8xf32> to memref<8xf32, #map>
+    %b_buf_view = memref.subview %b_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 7 } : memref<2x8xf32> to memref<8xf32, #map>
+    memref.copy %A_view , %a_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 8} : memref<8xf32, #map> to memref<8xf32, #map>
+    memref.copy %B_view , %b_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 9} : memref<8xf32, #map> to memref<8xf32, #map>
+    %C_view = memref.subview %result[%i0][8][1] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : memref<?xf32> to memref<8xf32, #map>
+    %scalar = arith.addf %cf, %cf {__test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1} : f32
+    linalg.generic #linalg_attrs ins(%a_buf_view, %b_buf_view, %scalar : memref<8xf32, #map>, memref<8xf32, #map>, f32)
+      outs(%C_view: memref<8xf32, #map>) {
+    ^bb0(%a: f32, %b: f32, %s: f32, %c: f32):
+      %add = arith.addf %a, %b : f32
+      %accum = arith.addf %add, %c : f32
+      %accum1 = arith.addf %scalar, %accum : f32
+      %accum2 = arith.addf %s, %accum1 : f32
+      linalg.yield %accum2 : f32
+    }
+    scf.yield
+  } { __test_pipelining_loop__ }
+  return
+}
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -7,6 +7,7 @@
   EXCLUDE_FROM_LIBMLIR
 
   LINK_LIBS PUBLIC
+  MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
   MLIRSCFTransforms
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -131,6 +132,7 @@
               std::vector<std::pair<Operation *, unsigned>> &schedule) {
     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
       return;
 
+    schedule.resize(forOp.getBody()->getOperations().size() - 1);
     forOp.walk([&schedule](Operation *op) {
       auto attrStage =
@@ -153,17 +155,30 @@
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
     // True branch.
     op->moveBefore(&ifOp.getThenRegion().front(),
-                   ifOp.getThenRegion().front().end());
+                   ifOp.getThenRegion().front().begin());
     rewriter.setInsertionPointAfter(op);
-    rewriter.create<scf::YieldOp>(loc, op->getResults());
+    if (op->getNumResults() > 0)
+      rewriter.create<scf::YieldOp>(loc, op->getResults());
     // False branch.
     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    SmallVector<Value> zeros;
-    for (Type type : op->getResultTypes()) {
-      zeros.push_back(
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
+    SmallVector<Value> elseYieldOperands;
+    elseYieldOperands.reserve(ifOp.getNumResults());
+    if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
+      // For sub-views, just clone the op.
+      // NOTE: This is okay in the test because we use dynamic memref sizes, so
+      // the verifier will not complain. Otherwise, we may create a logically
+      // out-of-bounds view and a different technique should be used.
+      Operation *opClone = rewriter.clone(*op);
+      elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
+    } else {
+      // Default to assuming constant numeric values.
+      for (Type type : op->getResultTypes()) {
+        elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(type)));
+      }
     }
-    rewriter.create<scf::YieldOp>(loc, zeros);
+    if (op->getNumResults() > 0)
+      rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
 
     return ifOp.getOperation();
   }
@@ -187,7 +202,7 @@
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithmeticDialect, memref::MemRefDialect>();
  }

  void runOnOperation() override {