diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -9,14 +9,18 @@
 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
 #define MLIR_DIALECT_LINALG_UTILS_H_
 
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
 class AffineExpr;
+class AffineForOp;
 class AffineMap;
 class OperationFolder;
 class PatternRewriter;
@@ -141,6 +145,26 @@
   inVec = auxVec;
 }
 
+/// Utility class used to generate nested loops with ranges described by
+/// `loopRanges` and loop type described by the `iteratorTypes`. `allIvs` is
+/// populated with induction variables for all generated loops on return, with
+/// `fun` used to generate the body of the innermost loop.
+///
+/// `iteratorTypes` is matched positionally against `loopRanges`: entry `i`
+/// describes the loop built from `loopRanges[i]`, so callers that drop or
+/// permute loops must drop or permute `iteratorTypes` identically.
+template <typename LoopTy>
+struct GenerateLoopNest {
+  using IndexedValueTy =
+      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
+                                edsc::intrinsics::AffineIndexedValue,
+                                edsc::intrinsics::StdIndexedValue>::type;
+  static void doit(MutableArrayRef<Value> allIvs,
+                   ArrayRef<SubViewOp::Range> loopRanges,
+                   ArrayRef<Attribute> iteratorTypes,
+                   std::function<void(void)> fun);
+};
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -487,80 +487,9 @@
   }
 };
 
-namespace {
-/// Helper struct to generate the loop nest for the op. This factored out here
-/// to be able to partially specialize this for different LoopTy.
-template <typename LoopTy, typename ConcreteOpTy>
-class GenerateLoopNest {
-public:
-  using IndexedValueTy =
-      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
-                                AffineIndexedValue, StdIndexedValue>::type;
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
-      SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-      LinalgScopedEmitter<IndexedValueTy,
-                          ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                   linalgOp);
-    });
-  }
-};
-
-/// Generates loop nest using scf.parallel. scf.parallel is only used for the
-/// outer parallel loops. All other loops are generated using scf.for
-/// operation.
-template <typename ConcreteOpTy>
-class GenerateLoopNest<scf::ParallelOp, ConcreteOpTy> {
-public:
-  using IndexedValueTy = StdIndexedValue;
-
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    // Only generate scf.parallel for outer consecutive "parallel"
-    // iterator_types.
-    // TODO(ravishankarm): Generate scf.parallel for all "parallel" iterator
-    // types, not just the outer most ones. Also handle "reduction" iterator
-    // types.
-    auto nOuterPar = linalgOp.iterator_types()
-                         .getValue()
-                         .take_while([](Attribute attr) {
-                           return attr.cast<StringAttr>().getValue() ==
-                                  getParallelIteratorTypeName();
-                         })
-                         .size();
-    // If there are no outer parallel loops, then number of loop ops is same as
-    // the number of loops, and they are all scf.for ops.
-    if (nOuterPar) {
-      GenericLoopNestRangeBuilder<scf::ParallelOp>(
-          allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
-        GenericLoopNestRangeBuilder<scf::ForOp>(
-            allIvs.drop_front(nOuterPar),
-            loopRanges.drop_front(nOuterPar))([&] {
-          SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-          LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
-              emitScalarImplementation(allIvValues, linalgOp);
-        });
-      });
-    } else {
-      // If there are no parallel loops then fallback to generating all scf.for
-      // operations.
-      GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)([&] {
-        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-        LinalgScopedEmitter<StdIndexedValue,
-                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                     linalgOp);
-      });
-    }
-  }
-};
-} // namespace
-
 template <typename LoopTy, typename ConcreteOpTy>
 Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
-  using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
-  using IndexedValueTy =
-      typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
+  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
@@ -591,7 +520,13 @@
       emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                      getViewSizes(builder, linalgOp));
   assert(loopRanges.size() == allIvs.size());
-  Impl::doit(linalgOp, loopRanges, allIvs);
+  GenerateLoopNest<LoopTy>::doit(
+      allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] {
+        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+        LinalgScopedEmitter<IndexedValueTy,
+                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                     linalgOp);
+      });
   // Number of loop ops might be different from the number of ivs since some
   // loops like affine.parallel and scf.parallel have multiple ivs.
   llvm::SetVector<Operation *> loopSet;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -376,24 +376,39 @@
   // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs(loopRanges.size());
-  GenericLoopNestRangeBuilder<LoopTy>(ivs, loopRanges)([&] {
-    auto &b = ScopedContext::getBuilderRef();
-    auto loc = ScopedContext::getLocation();
-    SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
-
-    // If we have to apply a permutation to the tiled loop nest, we have to
-    // reorder the induction variables This permutation is the right one
-    // assuming that loopRanges have previously been permuted by
-    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
-    // that one: (d0,d1,d2)->(d2,d0,d1)
-    if (!options.interchangeVector.empty())
-      ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
-
-    auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes);
-    auto operands = getAssumedNonViewOperands(op);
-    views.append(operands.begin(), operands.end());
-    res = op.clone(b, loc, views);
-  });
+  // GenerateLoopNest matches iterator types positionally against
+  // `loopRanges`, which only holds the loops with a non-zero tile size and is
+  // permuted by `options.interchangeVector` when one is given. Select and
+  // permute the op's iterator types the same way: passing them unfiltered can,
+  // e.g., emit a reduction loop as scf.parallel when an earlier loop is not
+  // tiled.
+  SmallVector<Attribute, 4> tiledIteratorTypes;
+  for (auto en : llvm::enumerate(op.iterator_types().getValue()))
+    if (loopIndexToRangeIndex.count(en.index()))
+      tiledIteratorTypes.push_back(en.value());
+  if (!options.interchangeVector.empty())
+    applyPermutationToVector(tiledIteratorTypes, options.interchangeVector);
+  assert(tiledIteratorTypes.size() == loopRanges.size() &&
+         "expected one iterator type per tiled loop");
+  GenerateLoopNest<LoopTy>::doit(
+      ivs, loopRanges, tiledIteratorTypes, [&] {
+        auto &b = ScopedContext::getBuilderRef();
+        auto loc = ScopedContext::getLocation();
+        SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
+
+        // If we have to apply a permutation to the tiled loop nest, we have to
+        // reorder the induction variables This permutation is the right one
+        // assuming that loopRanges have previously been permuted by
+        // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
+        // of that one: (d0,d1,d2)->(d2,d0,d1)
+        if (!options.interchangeVector.empty())
+          ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
+
+        auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes);
+        auto operands = getAssumedNonViewOperands(op);
+        views.append(operands.begin(), operands.end());
+        res = op.clone(b, loc, views);
+      });
 
   // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/SCF/EDSC/Builders.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" @@ -101,3 +102,63 @@ } return res; } + +/// Explicit instantiation of loop nest generator for different loop types. +template struct mlir::linalg::GenerateLoopNest; +template struct mlir::linalg::GenerateLoopNest; +template struct mlir::linalg::GenerateLoopNest; + +/// Specialization of loop nest generator for scf.parallel loops to handle +/// iterator types that are not parallel. These are generated as sequential +/// loops. +template <> +void mlir::linalg::GenerateLoopNest::doit( + MutableArrayRef allIvs, ArrayRef loopRanges, + ArrayRef iteratorTypes, std::function fun) { + edsc::GenericLoopNestRangeBuilder(allIvs, loopRanges)(fun); +} + +template <> +void mlir::linalg::GenerateLoopNest::doit( + MutableArrayRef allIvs, ArrayRef loopRanges, + ArrayRef iteratorTypes, std::function fun) { + edsc::GenericLoopNestRangeBuilder(allIvs, loopRanges)(fun); +} + +template <> +void mlir::linalg::GenerateLoopNest::doit( + MutableArrayRef allIvs, ArrayRef loopRanges, + ArrayRef iteratorTypes, std::function fun) { + if (loopRanges.empty()) + return; + size_t nOuterPar = iteratorTypes.take_front(loopRanges.size()) + .take_while([](Attribute attr) { + return attr.cast().getValue() == + getParallelIteratorTypeName(); + }) + .size(); + if (nOuterPar == 0 && loopRanges.size() == 1) + return GenerateLoopNest::doit(allIvs, loopRanges, iteratorTypes, + fun); + if (nOuterPar == 0) { + auto nestedFn = [&]() { + 
GenerateLoopNest::doit(allIvs.drop_front(), + loopRanges.drop_front(), + iteratorTypes.drop_front(), fun); + }; + return GenerateLoopNest::doit(allIvs[0], loopRanges[0], + iteratorTypes[0], nestedFn); + } + if (nOuterPar == loopRanges.size()) { + return edsc::GenericLoopNestRangeBuilder(allIvs, + loopRanges)(fun); + } + auto nestedFn = [&]() { + GenerateLoopNest::doit( + allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar), + iteratorTypes.drop_front(nOuterPar), fun); + }; + return GenerateLoopNest::doit( + allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar), + iteratorTypes.take_front(nOuterPar), nestedFn); +} diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir --- a/mlir/test/Dialect/Linalg/parallel_loops.mlir +++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir @@ -57,6 +57,42 @@ // CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3 // CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]]) // CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]] -// CHECK: scf.for %[[IV3:.*]] = %[[C0]] to %[[D3]] step %[[C1]] +// CHECK: scf.parallel (%[[IV3:.*]]) = (%[[C0]]) to (%[[D3]]) step (%[[C1]]) // CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV3]]] + +// ----- + +#accesses = [ + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)> +] +#trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], + indexing_maps = #accesses +} + +func @lower_mixed_parallel(%A: memref, %B: memref) { + linalg.generic #trait %A, %B { + ^bb0(%a: f32, %b: f32): + linalg.yield %a: f32 + } : memref, memref + return +} +// CHECK-LABEL: @lower_mixed_parallel +// CHECK-DAG: %[[C0:.*]] = constant 0 +// CHECK-DAG: %[[C1:.*]] = constant 1 +// CHECK-DAG: %[[D0:.*]] 
= dim %{{.*}}, 0 +// CHECK-DAG: %[[D1:.*]] = dim %{{.*}}, 1 +// CHECK-DAG: %[[D2:.*]] = dim %{{.*}}, 2 +// CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3 +// CHECK-DAG: %[[D4:.*]] = dim %{{.*}}, 4 +// CHECK-DAG: %[[D5:.*]] = dim %{{.*}}, 5 +// CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]]) +// CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]] +// CHECK: scf.parallel (%[[IV3:.*]], %[[IV4:.*]]) = (%[[C0]], %[[C0]]) to (%[[D3]], %[[D4]]) step (%[[C1]], %[[C1]]) +// CHECK: scf.for %[[IV5:.*]] = %[[C0]] to %[[D5]] step %[[C1]] +// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]] +// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV2]], %[[IV4]], %[[IV5]]] diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir @@ -0,0 +1,108 @@ +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4,8" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2" -split-input-file | FileCheck %s -check-prefix=TILE1 +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4" -split-input-file | FileCheck %s -check-prefix=TILE2 + +func @gemm(%arg0 : memref, + %arg1 : memref, + %arg2 : memref) +{ + linalg.matmul(%arg0, %arg1, %arg2) + : memref, memref, memref + return +} +// CHECK-LABEL: func @gemm +// CHECK-DAG: %[[C2:.*]] = constant 2 : index +// CHECK-DAG: %[[C4:.*]] = constant 4 : index +// CHECK-DAG: %[[C8:.*]] = constant 8 : index +// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = +// CHECK-SAME: step (%[[C2]], %[[C4]]) +// CHECK: scf.for %[[ARG5:.*]] = +// CHECK-SAME: step %[[C8]] +// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]] +// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]] +// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 
%[[ARG4]]] +// CHECK: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]]) + +// TILE1-LABEL: func @gemm +// TILE1-DAG: %[[C2:.*]] = constant 2 : index +// TILE1: scf.parallel (%[[ARG3:.*]]) = +// TILE1-SAME: step (%[[C2]]) +// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0] +// TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0] +// TILE1-NOT: subview +// TILE1: linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]]) + +// TILE2-LABEL: func @gemm +// TILE2-DAG: %[[C2:.*]] = constant 2 : index +// TILE2-DAG: %[[C4:.*]] = constant 4 : index +// TILE2: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = +// TILE2-SAME: step (%[[C2]], %[[C4]]) +// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0] +// TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]] +// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]] +// TILE2: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]]) + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1)> +#accesses = [#map0, #map1, #map2] +#trait = { + args_in = 2 : i64, + args_out = 1 : i64, + iterator_types = ["reduction", "parallel", "reduction"], + indexing_maps = #accesses +} + +func @reduction(%arg0 : memref, + %arg1 : memref, + %arg2 : memref) +{ + linalg.generic #trait %arg0, %arg1, %arg2 { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): + %0 = addf %arg3, %arg4 : f32 + %1 = addf %0, %arg5 : f32 + linalg.yield %1 : f32 + } : memref, memref, memref + return +} + +// CHECK-LABEL: func @reduction +// CHECK-DAG: %[[C2:.*]] = constant 2 : index +// CHECK-DAG: %[[C4:.*]] = constant 4 : index +// CHECK-DAG: %[[C8:.*]] = constant 8 : index +// CHECK: scf.for %[[ARG3:.*]] = +// CHECK-SAME: step %[[C2]] +// CHECK: scf.parallel (%[[ARG4:.*]]) = +// CHECK-SAME: step (%[[C4]]) +// CHECK: scf.for %[[ARG5:.*]] = +// CHECK-SAME: step %[[C8]] +// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], %[[ARG5]]] +// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]] +// CHECK: 
%[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]] +// CHECK: linalg.generic +// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]] + +// TILE1-LABEL: func @reduction +// TILE1-DAG: %[[C2:.*]] = constant 2 : index +// TILE1: scf.for %[[ARG3:.*]] = +// TILE1-SAME: step %[[C2]] +// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0, 0] +// TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0] +// TILE1-NOT: subview +// TILE1: linalg.generic +// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}} + +// TILE2-LABEL: func @reduction +// TILE2-DAG: %[[C2:.*]] = constant 2 : index +// TILE2-DAG: %[[C4:.*]] = constant 4 : index +// TILE2: scf.for %[[ARG3:.*]] = +// TILE2-SAME: step %[[C2]] +// TILE2: scf.parallel (%[[ARG4:.*]]) = +// TILE2-SAME: step (%[[C4]]) +// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], 0] +// TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0] +// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]] +// TILE2: linalg.generic +// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]] diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -44,7 +44,8 @@ // CHECK-DAG: %[[c0:.*]] = constant 0 : index // CHECK-DAG: %[[c5:.*]] = constant 5 : index // CHECK-DAG: %[[c6:.*]] = constant 6 : index -// CHECK: scf.parallel {{.*}} step (%[[c5]], %[[c6]]) +// CHECK: scf.parallel {{.*}} step (%[[c5]]) +// CHECK: scf.for {{.*}} step %[[c6]] // CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref func @matmul(%A: memref,