diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -18,7 +18,9 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -114,15 +116,28 @@
     return false;
 
   // All operations need to have a stage.
-  if (forOp
-          .walk([this](Operation *op) {
-            if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
-                stages.find(op) == stages.end())
-              return WalkResult::interrupt();
-            return WalkResult::advance();
-          })
-          .wasInterrupted())
-    return false;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (stages.find(&op) == stages.end()) {
+      op.emitOpError("not assigned a pipeline stage");
+      return false;
+    }
+  }
+
+  // Currently, we do not support assigning stages to ops in nested regions.
+  // The block of all operations assigned a stage should be the single
+  // `scf.for` body block.
+  for (const auto &[op, stageNum] : stages) {
+    (void)stageNum;
+    if (op == forOp.getBody()->getTerminator()) {
+      op->emitError("terminator should not be assigned a stage");
+      return false;
+    }
+    if (op->getBlock() != forOp.getBody()) {
+      op->emitOpError("the owning Block of all operations assigned a stage "
+                      "should be the loop body block");
+      return false;
+    }
+  }
 
   // Only support loop carried dependency with a distance of 1. This means the
   // source of all the scf.yield operands needs to be defined by operations in
@@ -158,6 +173,18 @@
       if (it != valueMapping.end())
         newOp->setOperand(opIdx, it->second[i - stages[op]]);
     }
+
+    // Remap values defined above that are used inside the op's nested
+    // regions; nested ops share the stage of their parent op.
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+      auto it = valueMapping.find(operand->get());
+      if (it != valueMapping.end()) {
+        for (Region &r : newOp->getRegions())
+          replaceAllUsesInRegionWith(operand->get(),
+                                     it->second[i - stages[op]], r);
+      }
+    });
+
     if (annotateFn)
       annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
     for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
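
A stage attribute is only accepted on ops whose parent block is the `scf.for` body; ops nested inside a region implicitly inherit the stage of the op that owns the region, which is why the prologue above only has to remap the values flowing into those regions. A minimal sketch of a valid assignment using the test pass's annotation scheme (`%value` and `%acc` are placeholder names); only the `scf.execute_region` itself carries a stage:

    %res = scf.execute_region -> f32 {
      %inner = arith.addf %value, %acc : f32
      scf.yield %inner : f32
    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 }
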
@@ -176,20 +203,56 @@
   }
 }
 
+/// Returns the operands of `op` as well as the operands of ops nested in
+/// `op`'s regions whose values are defined above `op` (or are block arguments
+/// of `forOp`'s body, when `forOp` is provided).
+static SmallVector<OpOperand *>
+getOpOperandsAndUsesAbove(Operation *op, scf::ForOp forOp = nullptr) {
+  SmallVector<OpOperand *> result;
+  for (OpOperand &operand : op->getOpOperands()) {
+    result.push_back(&operand);
+  }
+
+  // Collect the proper ancestor regions of each nested region upfront to
+  // avoid traversing the region tree for every value.
+  for (Region &region : op->getRegions()) {
+    llvm::SmallPtrSet<Region *, 4> properAncestors;
+    for (auto *reg = region.getParentRegion(); reg != nullptr;
+         reg = reg->getParentRegion()) {
+      properAncestors.insert(reg);
+    }
+
+    region.walk([&](Operation *nestedOp) {
+      for (OpOperand &operand : nestedOp->getOpOperands()) {
+        if (properAncestors.count(operand.get().getParentRegion())) {
+          result.push_back(&operand);
+          continue;
+        }
+        auto arg = operand.get().dyn_cast<BlockArgument>();
+        if (forOp && arg && arg.getOwner() == forOp.getBody()) {
+          result.push_back(&operand);
+        }
+      }
+    });
+  }
+
+  return result;
+}
+
 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
 LoopPipelinerInternal::analyzeCrossStageValues() {
   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
   for (Operation *op : opOrder) {
     unsigned stage = stages[op];
-    for (OpOperand &operand : op->getOpOperands()) {
-      Operation *def = operand.get().getDefiningOp();
+
+    for (OpOperand *operand : getOpOperandsAndUsesAbove(op)) {
+      Operation *def = operand->get().getDefiningOp();
       if (!def)
         continue;
       auto defStage = stages.find(def);
       if (defStage == stages.end() || defStage->second == stage)
         continue;
       assert(stage > defStage->second);
-      LiverangeInfo &info = crossStageValues[operand.get()];
+      LiverangeInfo &info = crossStageValues[operand->get()];
       info.defStage = defStage->second;
       info.lastUseStage = std::max(info.lastUseStage, stage);
     }
@@ -243,6 +306,10 @@
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
+  // When there are no iter args, the builder creates the loop body terminator
+  // automatically. Since we always create it below, erase it here if present.
+  if (!newForOp.getBody()->empty())
+    rewriter.eraseOp(newForOp.getBody()->getTerminator());
   return newForOp;
 }
 
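
Through getOpOperandsAndUsesAbove, analyzeCrossStageValues now also sees uses that are buried inside nested regions. In the region_backedge_different_stage test added below, the stage-0 load is consumed only by an addf nested in the stage-1 scf.execute_region, so its live range still spans one stage and the value gets an extra iter_args slot in the rewritten loop. Condensed excerpt of that test:

    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
    %A1_elem = scf.execute_region -> f32 {
      // Use of the stage-0 value %A_elem, one region deep.
      %inner = arith.addf %A_elem, %arg0 : f32
      scf.yield %inner : f32
    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 }
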
@@ -277,20 +344,39 @@
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
-    for (OpOperand &operand : op->getOpOperands()) {
+
+    auto updateNewOpOrNestedOps = [&](OpOperand *operand, Value replacement) {
+      if (operand->getOwner() == op) {
+        newOp->setOperand(operand->getOperandNumber(), replacement);
+        return;
+      }
+      for (Region &region : newOp->getRegions()) {
+        if (auto arg = operand->get().dyn_cast<BlockArgument>()) {
+          if (arg.getOwner() == forOp.getBody()) {
+            // Inside the cloned regions, block arguments of the original loop
+            // have already been remapped to the new loop's arguments by
+            // `clone`, so replace the corresponding new argument.
+            replaceAllUsesInRegionWith(mapping.lookupOrDefault(arg),
+                                       replacement, region);
+            continue;
+          }
+        }
+        replaceAllUsesInRegionWith(operand->get(), replacement, region);
+      }
+    };
+
+    for (OpOperand *operand : getOpOperandsAndUsesAbove(op, forOp)) {
       // Special case for the induction variable uses. We replace it with a
       // version incremented based on the stage where it is used.
-      if (operand.get() == forOp.getInductionVar()) {
+      if (operand->get() == forOp.getInductionVar()) {
         rewriter.setInsertionPoint(newOp);
         Value offset = rewriter.create<arith::ConstantIndexOp>(
             forOp.getLoc(), (maxStage - stages[op]) * step);
         Value iv = rewriter.create<arith::AddIOp>(
             forOp.getLoc(), newForOp.getInductionVar(), offset);
-        newOp->setOperand(operand.getOperandNumber(), iv);
+        updateNewOpOrNestedOps(operand, iv);
         rewriter.setInsertionPointAfter(newOp);
         continue;
       }
-      auto arg = operand.get().dyn_cast<BlockArgument>();
+      auto arg = operand->get().dyn_cast<BlockArgument>();
       if (arg && arg.getOwner() == forOp.getBody()) {
         // If the value is a loop carried value coming from stage N + 1 remap,
         // it will become a direct use.
@@ -303,24 +389,23 @@
         if (stageDep == stages.end() || stageDep->second == useStage)
           continue;
         assert(stageDep->second == useStage + 1);
-        newOp->setOperand(operand.getOperandNumber(),
-                          mapping.lookupOrDefault(ret));
+        updateNewOpOrNestedOps(operand, mapping.lookupOrDefault(ret));
         continue;
       }
       // For operands defined in a previous stage we need to remap it to use
       // the correct region argument. We look for the right version of the
       // Value based on the stage where it is used.
-      Operation *def = operand.get().getDefiningOp();
+      Operation *def = operand->get().getDefiningOp();
       if (!def)
         continue;
       auto stageDef = stages.find(def);
       if (stageDef == stages.end() || stageDef->second == useStage)
         continue;
       auto remap = loopArgMap.find(
-          std::make_pair(operand.get(), useStage - stageDef->second));
+          std::make_pair(operand->get(), useStage - stageDef->second));
       assert(remap != loopArgMap.end());
-      newOp->setOperand(operand.getOperandNumber(),
-                        newForOp.getRegionIterArgs()[remap->second]);
+      updateNewOpOrNestedOps(operand,
+                             newForOp.getRegionIterArgs()[remap->second]);
     }
     if (predicates[useStage]) {
       newOp = predicateFn(newOp, predicates[useStage], rewriter);
@@ -397,6 +482,18 @@
           newOp->setOperand(opIdx, v);
         }
       }
+
+      // Remap values defined above that are used inside the op's nested
+      // regions; nested ops share the stage of their parent op.
+      visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+        auto it = valueMapping.find(operand->get());
+        if (it != valueMapping.end()) {
+          for (Region &r : newOp->getRegions())
+            replaceAllUsesInRegionWith(
+                operand->get(), it->second[maxStage - stages[op] + i], r);
+        }
+      });
+
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
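
For a use that lives in a nested region, the kernel rewriting cannot simply call setOperand on the cloned op, so updateNewOpOrNestedOps rewrites the value inside the cloned regions instead. The expected kernel of the region_backedge_different_stage test shows the effect: both values read by the nested addf end up being iter_args of the new loop. A sketch condensed from the CHECK lines below (SSA names mirror the FileCheck captures; the stage-2 mulf is omitted):

    %r:3 = scf.for %iv = %c0 to %c2 step %c1
        iter_args(%c = %cstf, %addarg = %add0, %larg = %l1) -> (f32, f32, f32) {
      %add1 = scf.execute_region -> f32 {
        %v = arith.addf %larg, %addarg : f32
        scf.yield %v : f32
      }
      %iv2 = arith.addi %iv, %c2 : index
      %l2 = memref.load %A[%iv2] : memref<?xf32>
      scf.yield %addarg, %add1, %l2 : f32, f32, f32
    }
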
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -341,6 +341,52 @@
   return %r : f32
 }
 
+// -----
+
+// CHECK-LABEL: region_backedge_different_stage
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK: %[[ADD0:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32
+// CHECK: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+// CHECK: %[[ADD1:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[LARG]], %[[ADDARG]] : f32
+// CHECK: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK: %[[ADD2:.*]] = scf.execute_region
+// CHECK-NEXT: arith.addf %[[R]]#2, %[[R]]#1 : f32
+// CHECK: return %[[ADD2]] : f32
+
+func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = scf.execute_region -> f32 {
+      %inner = arith.addf %A_elem, %arg0 : f32
+      scf.yield %inner : f32
+    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 }
+    %A2_elem = arith.mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    scf.yield %A2_elem : f32
+  } { __test_pipelining_loop__ }
+  return %r : f32
+}
+
 // -----
 
 // CHECK-LABEL: backedge_same_stage
@@ -376,3 +422,81 @@
   } { __test_pipelining_loop__ }
   return %r : f32
 }
+
+// -----
+
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK: %[[C0:.+]] = arith.constant 0 :
+// CHECK: %[[C3:.+]] = arith.constant 3 :
+// CHECK: %[[C1:.+]] = arith.constant 1 :
+// CHECK: %[[APRO:.+]] = memref.alloc() :
+// CHECK: %[[BPRO:.+]] = memref.alloc() :
+// CHECK: %[[ASV0:.+]] = memref.subview %[[ARG0]][%[[C0]]] [8] [1] :
+// CHECK: %[[BSV0:.+]] = memref.subview %[[ARG1]][%[[C0]]] [8] [1] :
+
+// Prologue:
+// CHECK: %[[PAV0:.+]] = memref.subview %[[APRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: %[[PBV0:.+]] = memref.subview %[[BPRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV0]], %[[PAV0]] :
+// CHECK: memref.copy %[[BSV0]], %[[PBV0]] :
+
+// Kernel:
+// CHECK: %[[R:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[IA:.+]] = %[[PAV0]], %[[IB:.+]] = %[[PBV0]])
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[IA]], %[[IB]] : {{.*}}) outs(%[[CV]] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]]
+// CHECK: %[[ASV:.+]] = memref.subview %[[ARG0]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BSV:.+]] = memref.subview %[[ARG1]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BUFIDX:.+]] = affine.apply
+// CHECK: %[[APROSV:.+]] = memref.subview %[[APRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: %[[BPROSV:.+]] = memref.subview %[[BPRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV]], %[[APROSV]] :
+// CHECK: memref.copy %[[BSV]], %[[BPROSV]] :
+// CHECK: scf.yield %[[APROSV]], %[[BPROSV]] :
+// CHECK: }
+// Epilogue:
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]][%[[C3]]] [8] [1] :
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[R]]#0, %[[R]]#1 : {{.*}}) outs(%[[CV]] :
+
+#map = affine_map<(d0)[s0]->(d0 + s0)>
+#map1 = affine_map<(d0)->(d0)>
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %a_buf = memref.alloc() : memref<2x8xf32>
+  %b_buf = memref.alloc() : memref<2x8xf32>
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_view = memref.subview %A[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32> to memref<8xf32, #map>
+    %B_view = memref.subview %B[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32> to memref<8xf32, #map>
+    %buf_idx = affine.apply affine_map<(d0)->(d0 mod 2)> (%i0)[] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 4 }
+    %a_buf_view = memref.subview %a_buf[%buf_idx, 0][1, 8][1, 1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 5 } : memref<2x8xf32> to memref<8xf32, #map>
+    %b_buf_view = memref.subview %b_buf[%buf_idx, 0][1, 8][1, 1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 6 } : memref<2x8xf32> to memref<8xf32, #map>
+    memref.copy %A_view, %a_buf_view { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 7 } : memref<8xf32, #map> to memref<8xf32, #map>
+    memref.copy %B_view, %b_buf_view { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 8 } : memref<8xf32, #map> to memref<8xf32, #map>
+    %C_view = memref.subview %result[%i0][8][1] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : memref<?xf32> to memref<8xf32, #map>
+    linalg.generic {
+      indexing_maps = [
+        #map1,
+        #map1,
+        #map1
+      ],
+      iterator_types = ["parallel"],
+      __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1
+    } ins(%a_buf_view, %b_buf_view : memref<8xf32, #map>, memref<8xf32, #map>)
+      outs(%C_view : memref<8xf32, #map>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %add = arith.addf %a, %b : f32
+      %accum = arith.addf %add, %c : f32
+      linalg.yield %accum : f32
+    }
+  } { __test_pipelining_loop__ }
+  return
+}
+
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -6,7 +6,9 @@
 
   EXCLUDE_FROM_LIBMLIR
 
-  LINK_LIBS PUBLIC
+  LINK_LIBS PUBLIC
+  MLIRLinalgDialect
+  MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
   MLIRSCFTransforms
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -155,17 +157,30 @@
       rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
   // True branch.
   op->moveBefore(&ifOp.getThenRegion().front(),
-                 ifOp.getThenRegion().front().end());
+                 ifOp.getThenRegion().front().begin());
   rewriter.setInsertionPointAfter(op);
-  rewriter.create<scf::YieldOp>(loc, op->getResults());
+  if (op->getNumResults() > 0)
+    rewriter.create<scf::YieldOp>(loc, op->getResults());
   // False branch.
   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  SmallVector<Value> zeros;
-  for (Type type : op->getResultTypes()) {
-    zeros.push_back(
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
+  SmallVector<Value> elseYieldOperands;
+  elseYieldOperands.reserve(ifOp.getNumResults());
+  if (isa<memref::SubViewOp>(op)) {
+    // For sub-views, just clone the op.
+    // NOTE: This is okay in the test because we use dynamic memref sizes, so
+    // the verifier will not complain. Otherwise, we may create a logically
+    // out-of-bounds view and a different technique should be used.
+    Operation *opClone = rewriter.clone(*op);
+    elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
+  } else {
+    // Otherwise, assume the results are numeric and yield zero constants.
+    for (Type type : op->getResultTypes()) {
+      elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(type)));
+    }
   }
-  rewriter.create<scf::YieldOp>(loc, zeros);
+  if (op->getNumResults() > 0)
+    rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
 
   return ifOp.getOperation();
 }
@@ -189,7 +204,8 @@
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithmeticDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect>();
   }
 
   void runOnOperation() override {
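
With these changes, predicateOp can also guard the stage-0 memref.subview ops of the new test when the pipeliner predicates kernel stages instead of peeling an epilogue: the op moves into the then-branch of an scf.if, and because there is no meaningful zero constant for a memref view, the else-branch yields an unpredicated clone. An illustrative sketch of the resulting IR (not taken from a test; %pred, %A, %i and #map = affine_map<(d0)[s0] -> (d0 + s0)> are placeholders):

    %view = scf.if %pred -> (memref<8xf32, #map>) {
      %0 = memref.subview %A[%i] [8] [1] : memref<?xf32> to memref<8xf32, #map>
      scf.yield %0 : memref<8xf32, #map>
    } else {
      // Unpredicated clone; acceptable only because the source memref has a
      // dynamic size, so the verifier cannot reject the view statically.
      %1 = memref.subview %A[%i] [8] [1] : memref<?xf32> to memref<8xf32, #map>
      scf.yield %1 : memref<8xf32, #map>
    }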