diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -125,7 +125,21 @@ /// order picked for the pipelined loop. using GetScheduleFnType = std::function> &)>; - GetScheduleFnType getScheduleFn; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationFnType = + std::function<void(Operation *, PipelinerPart, unsigned)>; + AnnotationFnType annotateFn = nullptr; // TODO: add option to decide if the prologue/epilogue should be peeled. }; diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -41,6 +41,7 @@ int64_t ub; int64_t lb; int64_t step; + PipeliningOption::AnnotationFnType annotateFn = nullptr; // When peeling the kernel we generate several version of each value for // different stage of the prologue.
This map tracks the mapping between @@ -126,6 +127,7 @@ return !def || stages.find(def) == stages.end(); })) return false; + annotateFn = options.annotateFn; return true; } @@ -150,6 +152,8 @@ if (it != valueMapping.end()) newOp->setOperand(opIdx, it->second[i - stages[op]]); } + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(destId), newOp->getResult(destId), i - stages[op]); @@ -297,6 +301,8 @@ newOp->setOperand(operand.getOperandNumber(), newForOp.getRegionIterArgs()[remap->second]); } + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); } // Collect the Values that need to be returned by the forOp. For each @@ -363,6 +369,8 @@ newOp->setOperand(opIdx, v); } } + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(destId), newOp->getResult(destId), maxStage - stages[op] + i); diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-scf-pipelining=annotate -split-input-file | FileCheck %s --check-prefix ANNOTATE // CHECK-LABEL: simple_pipeline( // CHECK-SAME: %[[A:.*]]: memref, %[[R:.*]]: memref) { @@ -97,6 +98,22 @@ // CHECK-NEXT: memref.store %[[LR]]#0, %[[R]][%[[C2]]] : memref // CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[LR]]#1, %{{.*}} : f32 // CHECK-NEXT: memref.store %[[ADD2]], %[[R]][%[[C3]]] : memref + +// Prologue: +// ANNOTATE: memref.load {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "prologue"} +// ANNOTATE: memref.load {{.*}} {__test_pipelining_iteration = 1 : i32,
__test_pipelining_part = "prologue"} +// Kernel: +// ANNOTATE: scf.for +// ANNOTATE: memref.store {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "kernel"} +// ANNOTATE: arith.addf {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "kernel"} +// ANNOTATE: memref.load {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "kernel"} +// ANNOTATE: scf.yield +// ANNOTATE: } +// Epilogue: +// ANNOTATE: memref.store {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "epilogue"} +// ANNOTATE: arith.addf {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "epilogue"} +// ANNOTATE: memref.store {{.*}} {__test_pipelining_iteration = 1 : i32, __test_pipelining_part = "epilogue"} + func @three_stage(%A: memref, %result: memref) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -90,12 +90,23 @@ static const StringLiteral kTestPipeliningOpOrderMarker = "__test_pipelining_op_order__"; +static const StringLiteral kTestPipeliningAnnotationPart = + "__test_pipelining_part"; +static const StringLiteral kTestPipeliningAnnotationIteration = + "__test_pipelining_iteration"; + class TestSCFPipeliningPass : public PassWrapper> { public: + TestSCFPipeliningPass() = default; + TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} StringRef getArgument() const final { return "test-scf-pipelining"; } StringRef getDescription() const final { return "test scf.forOp pipelining"; } - explicit TestSCFPipeliningPass() = default; + + Option<bool> annotatePipeline{ + *this, "annotate", + llvm::cl::desc("Annotate operations during loop pipelining transformation"), + llvm::cl::init(false)}; static void getSchedule(scf::ForOp forOp, @@ -115,6 +126,25 @@ }); } + static void annotate(Operation *op,
+ mlir::scf::PipeliningOption::PipelinerPart part, + unsigned iteration) { + OpBuilder b(op); + switch (part) { + case mlir::scf::PipeliningOption::PipelinerPart::Prologue: + op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); + break; + case mlir::scf::PipeliningOption::PipelinerPart::Kernel: + op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); + break; + case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: + op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); + break; + } + op->setAttr(kTestPipeliningAnnotationIteration, + b.getI32IntegerAttr(iteration)); + } + void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } @@ -123,7 +153,8 @@ RewritePatternSet patterns(&getContext()); mlir::scf::PipeliningOption options; options.getScheduleFn = getSchedule; - + if (annotatePipeline) + options.annotateFn = annotate; scf::populateSCFLoopPipeliningPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); getOperation().walk([](Operation *op) {