diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -28,8 +28,8 @@ let options = [ Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool", /*default=*/"false", - "Only folds the one-trip loops from Linalg ops on tensors " - "(for testing purposes only)"> + "Only folds the one-trip loops from Linalg ops on tensors " + "(for testing purposes only)"> ]; let dependentDialects = ["linalg::LinalgDialect"]; } @@ -52,12 +52,24 @@ let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"]; } def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"]; } @@ -72,6 +84,12 @@ let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = 
["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -267,16 +267,28 @@ /// Emits a loop nest of `LoopTy` with the proper body for `op`. template -Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); - -/// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); - -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); +Optional +linalgLowerOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector = {}); + +/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated +/// loop nest will follow the `interchangeVector`-permutated iterator order. If +/// `interchangeVector` is empty, then no permutation happens. +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector = {}); + +/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The +/// generated loop nest will follow the `interchangeVector`-permutated +/// iterator order. If `interchangeVector` is empty, no permutation happens. +LogicalResult +linalgOpToParallelLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector = {}); -/// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); +/// Emits a loop nest of `affine.for` with the proper body for `op`. The +/// generated loop nest will follow the `interchangeVector`-permutated +/// iterator order. If `interchangeVector` is empty, no permutation happens.
+LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector = {}); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -587,13 +599,17 @@ AffineLoops = 2, ParallelLoops = 3 }; + template struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType, LinalgMarker marker = LinalgMarker(), + ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) : RewritePattern(OpTy::getOperationName(), {}, benefit, context), - marker(marker), loweringType(loweringType) {} + marker(marker), loweringType(loweringType), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + // TODO: Move implementation to .cpp once named ops are auto-generated. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -603,18 +619,24 @@ if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); - if (loweringType == LinalgLoweringType::LibraryCall) { + switch (loweringType) { + case LinalgLoweringType::LibraryCall: // TODO: Move lowering to library calls here. 
return failure(); - } else if (loweringType == LinalgLoweringType::Loops) { - if (failed(linalgOpToLoops(rewriter, op))) + case LinalgLoweringType::Loops: + if (failed(linalgOpToLoops(rewriter, op, interchangeVector))) return failure(); - } else if (loweringType == LinalgLoweringType::AffineLoops) { - if (failed(linalgOpToAffineLoops(rewriter, op))) + break; + case LinalgLoweringType::AffineLoops: + if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector))) return failure(); - } else if (failed(linalgOpToParallelLoops(rewriter, op))) { - return failure(); + break; + case LinalgLoweringType::ParallelLoops: + if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector))) + return failure(); + break; } + rewriter.eraseOp(op); return success(); } @@ -625,6 +647,8 @@ /// Controls whether the pattern lowers to library calls, scf.for, affine.for /// or scf.parallel. LinalgLoweringType loweringType; + /// Permutated loop order in the generated loop nest. + SmallVector interchangeVector; }; /// Linalg generalization patterns diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -23,7 +23,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -505,10 +504,10 @@ } template -static Optional linalgOpToLoopsImpl(Operation *op, - OpBuilder &builder) { +static Optional +linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, + ArrayRef interchangeVector) { using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(builder, op->getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible @@ -516,10 +515,20 @@ auto linalgOp = cast(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer 
semantics"); + auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); + auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); + + if (!interchangeVector.empty()) { + assert(interchangeVector.size() == loopRanges.size()); + assert(interchangeVector.size() == iteratorTypes.size()); + applyPermutationToVector(loopRanges, interchangeVector); + applyPermutationToVector(iteratorTypes, interchangeVector); + } + SmallVector allIvs; GenerateLoopNest::doit( - loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(), + loopRanges, /*iterInitArgs=*/{}, iteratorTypes, [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); @@ -552,26 +561,33 @@ template class LinalgRewritePattern : public RewritePattern { public: - LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + LinalgRewritePattern(ArrayRef interchangeVector) + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!isa(op)) return failure(); - if (!linalgOpToLoopsImpl(op, rewriter)) + if (!linalgOpToLoopsImpl(op, rewriter, interchangeVector)) return failure(); rewriter.eraseOp(op); return success(); } + +private: + SmallVector interchangeVector; }; struct FoldAffineOp; } // namespace template -static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { +static void lowerLinalgToLoopsImpl(FuncOp funcOp, + ArrayRef interchangeVector) { + MLIRContext *context = funcOp.getContext(); OwningRewritePatternList patterns; - patterns.insert>(); + patterns.insert>(interchangeVector); DimOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); @@ -620,20 +636,20 @@ struct LowerToAffineLoops : public 
LinalgLowerToAffineLoopsBase { void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), &getContext()); + lowerLinalgToLoopsImpl(getFunction(), interchangeVector); } }; struct LowerToLoops : public LinalgLowerToLoopsBase { void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), &getContext()); + lowerLinalgToLoopsImpl(getFunction(), interchangeVector); } }; struct LowerToParallelLoops : public LinalgLowerToParallelLoopsBase { void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), &getContext()); + lowerLinalgToLoopsImpl(getFunction(), interchangeVector); } }; } // namespace @@ -654,38 +670,43 @@ /// Emits a loop nest with the proper body for `op`. template -Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op) { - return linalgOpToLoopsImpl(op, builder); +Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector) { + return linalgOpToLoopsImpl(op, builder, interchangeVector); } +template Optional mlir::linalg::linalgLowerOpToLoops( + OpBuilder &builder, Operation *op, ArrayRef interchangeVector); +template Optional mlir::linalg::linalgLowerOpToLoops( + OpBuilder &builder, Operation *op, ArrayRef interchangeVector); template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); -template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); -template Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op); +mlir::linalg::linalgLowerOpToLoops( + OpBuilder &builder, Operation *op, ArrayRef interchangeVector); /// Emits a loop nest of `affine.for` with the proper body for `op`. 
-LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, - Operation *op) { - Optional loops = linalgLowerOpToLoops(builder, op); +LogicalResult +mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector) { + Optional loops = + linalgLowerOpToLoops(builder, op, interchangeVector); return loops ? success() : failure(); } /// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { - Optional loops = linalgLowerOpToLoops(builder, op); +LogicalResult +mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector) { + Optional loops = + linalgLowerOpToLoops(builder, op, interchangeVector); return loops ? success() : failure(); } /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, - Operation *op) { +LogicalResult +mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op, + ArrayRef interchangeVector) { Optional loops = - linalgLowerOpToLoops(builder, op); + linalgLowerOpToLoops(builder, op, interchangeVector); return loops ? 
success() : failure(); } diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/loop-order.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s +// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s +// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s + +func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) { + linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32> + return +} + +// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1 + +// PARALLEL: scf.parallel +// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) + +// AFFINE: affine.for %{{.*}} = 0 to 5 +// AFFINE: affine.for %{{.*}} = 0 to 1 +// AFFINE: affine.for %{{.*}} = 0 to 4 +// AFFINE: affine.for %{{.*}} = 0 to 2 +// AFFINE: affine.for %{{.*}} = 0 to 3 +