diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -114,15 +114,28 @@
     return false;
 
   // All operations need to have a stage.
-  if (forOp
-          .walk([this](Operation *op) {
-            if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
-                stages.find(op) == stages.end())
-              return WalkResult::interrupt();
-            return WalkResult::advance();
-          })
-          .wasInterrupted())
-    return false;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (stages.find(&op) == stages.end()) {
+      op.emitOpError("not assigned a pipeline stage");
+      return false;
+    }
+  }
+
+  // Currently, we do not support assigning stages to ops in nested regions. The
+  // block of all operations assigned a stage should be the single `scf.for`
+  // body block.
+  for (const auto &[op, stageNum] : stages) {
+    (void)stageNum;
+    if (op == forOp.getBody()->getTerminator()) {
+      op->emitError("terminator should not be assigned a stage");
+      return false;
+    }
+    if (op->getBlock() != forOp.getBody()) {
+      op->emitOpError("the owning Block of all operations assigned a stage "
+                      "should be the loop body block");
+      return false;
+    }
+  }
 
   // Only support loop carried dependency with a distance of 1. This means the
   // source of all the scf.yield operands needs to be defined by operations in
@@ -243,6 +256,10 @@
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
+  // When there are no iter args, the loop body terminator will be created.
+  // Since we always create it below, remove the terminator if it was created.
+  if (!newForOp.getBody()->empty())
+    rewriter.eraseOp(newForOp.getBody()->getTerminator());
   return newForOp;
 }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -376,3 +376,80 @@
   } { __test_pipelining_loop__ }
   return %r : f32
 }
+
+// -----
+
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK: %[[C0:.+]] = arith.constant 0 :
+// CHECK: %[[C3:.+]] = arith.constant 3 :
+// CHECK: %[[C1:.+]] = arith.constant 1 :
+// CHECK: %[[APRO:.+]] = memref.alloc() :
+// CHECK: %[[BPRO:.+]] = memref.alloc() :
+// CHECK: %[[ASV0:.+]] = memref.subview %[[ARG0]][%[[C0]]] [8] [1] :
+// CHECK: %[[BSV0:.+]] = memref.subview %[[ARG1]][%[[C0]]] [8] [1] :
+
+// Prologue:
+// CHECK: %[[PAV0:.+]] = memref.subview %[[APRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: %[[PBV0:.+]] = memref.subview %[[BPRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV0]], %[[PAV0]] :
+// CHECK: memref.copy %[[BSV0]], %[[PBV0]] :
+
+// Kernel:
+// CHECK: %[[R:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[IA:.+]] = %[[PAV0]], %[[IB:.+]] = %[[PBV0]])
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[IA]], %[[IB]] : {{.*}}) outs(%[[CV]] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]]
+// CHECK: %[[ASV:.+]] = memref.subview %[[ARG0]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BSV:.+]] = memref.subview %[[ARG1]][%[[NEXT]]] [8] [1] :
+// CHECK: %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK: %[[BUFIDX:.+]] = affine.apply
+// CHECK: %[[APROSV:.+]] = memref.subview %[[APRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: %[[BPROSV:.+]] = memref.subview %[[BPRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] :
+// CHECK: memref.copy %[[ASV]], %[[APROSV]] :
+// CHECK: memref.copy %[[BSV]], %[[BPROSV]] :
+// CHECK: scf.yield %[[APROSV]], %[[BPROSV]] :
+// CHECK: }
+// CHECK: %[[CV:.+]] = memref.subview %[[ARG2]][%[[C3]]] [8] [1] :
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[R]]#0, %[[R]]#1 : {{.*}}) outs(%[[CV]] :
+
+
+#map = affine_map<(d0)[s0]->(d0 + s0)>
+#map1 = affine_map<(d0)->(d0)>
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %a_buf = memref.alloc() : memref<2x8xf32>
+  %b_buf = memref.alloc() : memref<2x8xf32>
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_view = memref.subview %A[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32> to memref<8xf32, #map>
+    %B_view = memref.subview %B[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32> to memref<8xf32, #map>
+    %buf_idx = affine.apply affine_map<(d0)->(d0 mod 2)> (%i0)[] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 4 }
+    %a_buf_view = memref.subview %a_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 5 } : memref<2x8xf32> to memref<8xf32, #map>
+    %b_buf_view = memref.subview %b_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 6 } : memref<2x8xf32> to memref<8xf32, #map>
+    memref.copy %A_view , %a_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 7} : memref<8xf32, #map> to memref<8xf32, #map>
+    memref.copy %B_view , %b_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 8} : memref<8xf32, #map> to memref<8xf32, #map>
+    %C_view = memref.subview %result[%i0][8][1] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : memref<?xf32> to memref<8xf32, #map>
+    linalg.generic {
+      indexing_maps = [
+        #map1,
+        #map1,
+        #map1
+      ],
+      iterator_types = ["parallel"],
+      __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1
+    } ins(%a_buf_view, %b_buf_view : memref<8xf32, #map>, memref<8xf32, #map>)
+      outs(%C_view: memref<8xf32, #map>) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %add = arith.addf %a, %b : f32
+        %accum = arith.addf %add, %c : f32
+        linalg.yield %accum : f32
+    }
+  } { __test_pipelining_loop__ }
+  return
+}
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -6,7 +6,9 @@
 
   EXCLUDE_FROM_LIBMLIR
 
-  LINK_LIBS PUBLIC 
+  LINK_LIBS PUBLIC
+  MLIRLinalgDialect
+  MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
   MLIRSCFTransforms
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -155,17 +157,30 @@
       rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
   // True branch.
   op->moveBefore(&ifOp.getThenRegion().front(),
-                 ifOp.getThenRegion().front().end());
+                 ifOp.getThenRegion().front().begin());
   rewriter.setInsertionPointAfter(op);
-  rewriter.create<scf::YieldOp>(loc, op->getResults());
+  if (op->getNumResults() > 0)
+    rewriter.create<scf::YieldOp>(loc, op->getResults());
   // False branch.
   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  SmallVector<Value> zeros;
-  for (Type type : op->getResultTypes()) {
-    zeros.push_back(
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
+  SmallVector<Value> elseYieldOperands;
+  elseYieldOperands.reserve(ifOp.getNumResults());
+  if (isa<memref::SubViewOp>(op)) {
+    // For sub-views, just clone the op.
+    // NOTE: This is okay in the test because we use dynamic memref sizes, so
+    // the verifier will not complain. Otherwise, we may create a logically
+    // out-of-bounds view and a different technique should be used.
+    Operation *opClone = rewriter.clone(*op);
+    elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
+  } else {
+    // Default to assuming constant numeric values.
+    for (Type type : op->getResultTypes()) {
+      elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(type)));
+    }
   }
-  rewriter.create<scf::YieldOp>(loc, zeros);
+  if (op->getNumResults() > 0)
+    rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
   return ifOp.getOperation();
 }
 
@@ -189,7 +204,8 @@
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithmeticDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect>();
   }
 
   void runOnOperation() override {