Lower linalg ops to loops through PatternRewriter / LinalgOp

Summary of the change below:
* linalgLowerOpToLoops, linalgOpToLoops, linalgOpToParallelLoops and
  linalgOpToAffineLoops now take (PatternRewriter &, LinalgOp) instead of
  (OpBuilder &, Operation *).
* linalgLowerOpToLoops now calls replaceIndexOpsByInductionVariables on the
  emitted loops when lowering succeeds. The Loops.cpp code that previously
  called linalgOpToLoopsImpl and replaceIndexOpsByInductionVariables itself
  now just calls linalgLowerOpToLoops.
* linalgOpToLoopsImpl takes the LinalgOp directly instead of casting an
  Operation *. linalgLowerOpToLoops passes its LinalgOp straight through,
  with no getOperation() round trip.

NOTE(review): in this copy, template argument lists appear to have been
stripped (e.g. a bare `template` line, `Optional` with no element type) and
line breaks are collapsed, so the hunks below will not apply as-is.
Regenerate the patch from the source commit before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -337,18 +337,21 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, SmallVectorImpl &newResults); -/// Emits a loop nest of `LoopTy` with the proper body for `op`. +/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`. template -Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); +Optional linalgLowerOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); -/// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); +/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. +LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp); -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); +/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. +LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); -/// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); +/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. 
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -457,18 +457,17 @@ } template -static Optional linalgOpToLoopsImpl(Operation *op, +static Optional linalgOpToLoopsImpl(LinalgOp linalgOp, OpBuilder &builder) { using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(builder, op->getLoc()); + ScopedContext scope(builder, linalgOp.getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). - auto linalgOp = cast(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); + auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); SmallVector allIvs; @@ -477,7 +476,7 @@ [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); - llvm::TypeSwitch(op) + llvm::TypeSwitch(linalgOp) .Case([&](auto op) { emitScalarImplementation(allIvs, op); @@ -546,10 +545,8 @@ auto linalgOp = dyn_cast(op); if (!isa(op)) return failure(); - Optional loopOps = linalgOpToLoopsImpl(op, rewriter); - if (!loopOps.hasValue()) + if (!linalgLowerOpToLoops(rewriter, linalgOp)) return failure(); - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); rewriter.eraseOp(op); return success(); } @@ -695,40 +692,48 @@ return std::make_unique(); } -/// Emits a loop nest with the 
proper body for `op`. +/// Emits a loop nest with the proper body for `linalgOp`. template -Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op) { - return linalgOpToLoopsImpl(op, builder); +Optional +mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + Optional loopOps = + linalgOpToLoopsImpl(linalgOp, rewriter); + if (loopOps.hasValue()) + replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); + return loopOps; } template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); +mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); +mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); +mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); -/// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, - Operation *op) { - Optional loops = linalgLowerOpToLoops(builder, op); +/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. +LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + Optional loops = + linalgLowerOpToLoops(rewriter, linalgOp); return loops ? success() : failure(); } -/// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { - Optional loops = linalgLowerOpToLoops(builder, op); +/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. +LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + Optional loops = + linalgLowerOpToLoops(rewriter, linalgOp); return loops ? 
success() : failure(); } -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, - Operation *op) { +/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. +LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { Optional loops = - linalgLowerOpToLoops(builder, op); + linalgLowerOpToLoops(rewriter, linalgOp); return loops ? success() : failure(); }