Fold linalgLowerOpToLoops into linalgOpToLoopsImpl: linalgOpToLoops, linalgOpToParallelLoops and linalgOpToAffineLoops now return the generated loop nest (llvm::None on failure, e.g. for indexed_generic ops) instead of a LogicalResult, and index ops are replaced by induction variables inside linalgOpToLoopsImpl. NOTE(review): template argument lists (`<...>`) and line breaks appear to have been lost in this copy of the patch; restore them from the upstream commit before applying.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -343,21 +343,17 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, SmallVectorImpl &newResults); -/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`. -template -Optional linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); - /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. -LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp); +Optional linalgOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. -LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); +Optional linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. -LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); +Optional linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -825,15 +821,15 @@ // TODO: Move lowering to library calls here.
return failure(); case LinalgLoweringType::Loops: - if (failed(linalgOpToLoops(rewriter, op))) + if (!linalgOpToLoops(rewriter, op)) return failure(); break; case LinalgLoweringType::AffineLoops: - if (failed(linalgOpToAffineLoops(rewriter, op))) + if (!linalgOpToAffineLoops(rewriter, op)) return failure(); break; case LinalgLoweringType::ParallelLoops: - if (failed(linalgOpToParallelLoops(rewriter, op))) + if (!linalgOpToParallelLoops(rewriter, op)) return failure(); break; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -378,18 +378,54 @@ getPoolingInput(op, indices.inputs); } +/// Replace the index operations in the body of the loop nest by the matching +/// induction variables. +static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, + PatternRewriter &rewriter, + ArrayRef loopOps) { + // Extract the induction variables of the loop nest from outer to inner. + SmallVector allIvs; + for (Operation *loopOp : loopOps) { + llvm::TypeSwitch(loopOp) + .Case([&](scf::ParallelOp parallelOp) { + allIvs.append(parallelOp.getInductionVars().begin(), + parallelOp.getInductionVars().end()); + }) + .Case([&](scf::ForOp forOp) { + allIvs.push_back(forOp.getInductionVar()); + }) + .Case([&](AffineForOp affineForOp) { + allIvs.push_back(affineForOp.getInductionVar()); + }) + .Default([&](Operation *op) { llvm_unreachable("unexpected op"); }); + } + assert(linalgOp.getNumLoops() == allIvs.size() && + "expected the number of loops and induction variables to match"); + // Replace the index operations in the body of the innermost loop op.
+ if (!loopOps.empty()) { + LoopLikeOpInterface loopOp = loopOps.back(); + for (IndexOp indexOp : + llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) + rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); + } +} + template -static Optional linalgOpToLoopsImpl(LinalgOp linalgOp, - OpBuilder &builder) { +static Optional linalgOpToLoopsImpl(PatternRewriter &rewriter, + LinalgOp linalgOp) { using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(builder, linalgOp.getLoc()); + ScopedContext scope(rewriter, linalgOp.getLoc()); + + // Bail out on indexed_generic ops; canonicalize them to generic ops first. + if (isa(linalgOp)) + return llvm::None; // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc()); + auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); SmallVector allIvs; @@ -420,41 +456,11 @@ loopSet.insert(ivVal.getOwner()->getParentOp()); } LinalgLoops loops(loopSet.begin(), loopSet.end()); + // Replace all index operations in the loop body. + replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops); return loops; } -/// Replace the index operations in the body of the loop nest by the matching -/// induction variables. -static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, - PatternRewriter &rewriter, - ArrayRef loopOps) { - // Extract the induction variables of the loop nest from outer to inner.
- SmallVector allIvs; - for (Operation *loopOp : loopOps) { - llvm::TypeSwitch(loopOp) - .Case([&](scf::ParallelOp parallelOp) { - allIvs.append(parallelOp.getInductionVars().begin(), - parallelOp.getInductionVars().end()); - }) - .Case([&](scf::ForOp forOp) { - allIvs.push_back(forOp.getInductionVar()); - }) - .Case([&](AffineForOp affineForOp) { - allIvs.push_back(affineForOp.getInductionVar()); - }) - .Default([&](Operation *op) { assert(false && "unexpected op"); }); - } - assert(linalgOp.getNumLoops() == allIvs.size() && - "expected the number of loops and induction variables to match"); - // Replace the index operations in the body of the innermost loop op. - if (!loopOps.empty()) { - LoopLikeOpInterface loopOp = loopOps.back(); - for (IndexOp indexOp : - llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) - rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); - } -} - namespace { template class LinalgRewritePattern : public RewritePattern { @@ -467,7 +473,7 @@ auto linalgOp = dyn_cast(op); if (!isa(op)) return failure(); - if (!linalgLowerOpToLoops(rewriter, linalgOp)) + if (!linalgOpToLoopsImpl(rewriter, linalgOp)) return failure(); rewriter.eraseOp(op); return success(); @@ -614,52 +620,22 @@ return std::make_unique(); } -/// Emits a loop nest with the proper body for `linalgOp`. -template -Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - // Convert indexed_generic ops to generic ops before lowering them to loops.
- if (isa(linalgOp)) - return llvm::None; - - Optional loopOps = - linalgOpToLoopsImpl(linalgOp.getOperation(), rewriter); - if (loopOps.hasValue()) - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); - return loopOps; -} - -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); -template Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp); - /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional +mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. -LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { - Optional loops = - linalgLowerOpToLoops(rewriter, linalgOp); - return loops ? success() : failure(); +Optional +mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, + LinalgOp linalgOp) { + return linalgOpToLoopsImpl(rewriter, linalgOp); }