diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -77,8 +77,9 @@
 template
 class GenericLoopNestRangeBuilder {
 public:
   GenericLoopNestRangeBuilder(MutableArrayRef ivs,
-                              ArrayRef ranges);
-  void operator()(std::function fun = nullptr) { (*builder)(fun); }
+                              ArrayRef ranges,
+                              ArrayRef iteratorTypes);
+  void operator()(std::function fun = nullptr);
 private:
   using LoopOrAffineLoopBuilder =
@@ -90,6 +91,8 @@
                                   LoopOrAffineLoopBuilder>;
 
   std::unique_ptr builder;
+
+  std::unique_ptr forLoopBuilder;
 };
 
 inline void defaultRegionBuilder(ArrayRef args) {}
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -80,13 +80,21 @@
 
 template <>
 GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder(
-    MutableArrayRef ivs, ArrayRef ranges) {
+    MutableArrayRef ivs, ArrayRef ranges,
+    ArrayRef iteratorTypes) {
   builder = std::make_unique(ivs, ranges);
 }
 
+template <>
+void GenericLoopNestRangeBuilder::operator()(
+    std::function fun) {
+  (*builder)(fun);
+}
+
 template <>
 GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder(
-    MutableArrayRef ivs, ArrayRef ranges) {
+    MutableArrayRef ivs, ArrayRef ranges,
+    ArrayRef iteratorTypes) {
   SmallVector lbs;
   SmallVector ubs;
   SmallVector steps;
@@ -101,19 +109,56 @@
   builder = std::make_unique(ivs, lbs, ubs, steps);
 }
 
+template <>
+void GenericLoopNestRangeBuilder::operator()(
+    std::function fun) {
+  (*builder)(fun);
+}
+
 template <>
 GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder(
-    MutableArrayRef ivs, ArrayRef ranges) {
+    MutableArrayRef ivs, ArrayRef ranges,
+    ArrayRef iteratorTypes) {
   SmallVector lbs, ubs, steps;
-  for (Value range : ranges) {
-    assert(range.getType() && "expected linalg.range type");
-    assert(range.getDefiningOp() && "need operations to extract range parts");
-    RangeOp rangeOp = cast(range.getDefiningOp());
+  // Only the leading run of "parallel" iterator types is handed to `builder`;
+  // any remaining ivs/ranges are handed to `forLoopBuilder` below.
+  size_t nOuterPar = iteratorTypes
+                         .take_while([](Attribute attr) {
+                           return attr.cast().getValue() ==
+                                  getParallelIteratorTypeName();
+                         })
+                         .size();
+  nOuterPar = std::min(nOuterPar, ivs.size());
+  for (size_t i = 0; i != nOuterPar; ++i) {
+    assert(ranges[i].getType() && "expected linalg.range type");
+    assert(ranges[i].getDefiningOp() &&
+           "need operations to extract range parts");
+    RangeOp rangeOp = cast(ranges[i].getDefiningOp());
     lbs.emplace_back(rangeOp.min());
     ubs.emplace_back(rangeOp.max());
     steps.emplace_back(rangeOp.step());
   }
-  builder = std::make_unique(ivs, lbs, ubs, steps);
+  if (nOuterPar)
+    builder = std::make_unique(
+        ivs.take_front(nOuterPar), lbs, ubs, steps);
+  if (nOuterPar != ivs.size())
+    forLoopBuilder = std::make_unique(
+        ivs.drop_front(nOuterPar), ranges.drop_front(nOuterPar));
+}
+
+template <>
+void GenericLoopNestRangeBuilder::operator()(
+    std::function fun) {
+  // With both nests present, `builder` runs a body that in turn runs
+  // `forLoopBuilder` around `fun`.
+  auto forLoopFn = forLoopBuilder ? [&]() { (*forLoopBuilder)(fun); } : fun;
+  if (builder)
+    (*builder)(forLoopFn);
+  else if (forLoopBuilder)
+    // Only `forLoopBuilder` exists: invoke it once, directly on `fun`.
+    // `forLoopFn` already wraps `forLoopBuilder`, so passing it here would
+    // re-enter the same builder from inside its own body.
+    (*forLoopBuilder)(fun);
 }
 
 } // namespace edsc
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -500,7 +500,8 @@
                          AffineIndexedValue, StdIndexedValue>::type;
   static void doit(ConcreteOpTy linalgOp, ArrayRef loopRanges,
                    MutableArrayRef allIvs) {
-    GenericLoopNestRangeBuilder(allIvs, loopRanges)([&] {
+    GenericLoopNestRangeBuilder(
+        allIvs, loopRanges, linalgOp.iterator_types().getValue())([&] {
       SmallVector allIvValues(allIvs.begin(), allIvs.end());
       LinalgScopedEmitter::emitScalarImplementation(allIvValues,
@@ -508,54 +509,6 @@
     });
   }
 };
-
-/// Generates loop nest using scf.parallel. scf.parallel is only used for the
-/// outer parallel loops. All other loops are generated using scf.for
-/// operation.
-template
-class GenerateLoopNest {
-public:
-  using IndexedValueTy = StdIndexedValue;
-
-  static void doit(ConcreteOpTy linalgOp, ArrayRef loopRanges,
-                   MutableArrayRef allIvs) {
-    // Only generate scf.parallel for outer consecutive "parallel"
-    // iterator_types.
-    // TODO(ravishankarm): Generate scf.parallel for all "parallel" iterator
-    // types, not just the outer most ones. Also handle "reduction" iterator
-    // types.
-    auto nOuterPar = linalgOp.iterator_types()
-                         .getValue()
-                         .take_while([](Attribute attr) {
-                           return attr.cast().getValue() ==
-                                  getParallelIteratorTypeName();
-                         })
-                         .size();
-    // If there are no outer parallel loops, then number of loop ops is same as
-    // the number of loops, and they are all scf.for ops.
-    if (nOuterPar) {
-      GenericLoopNestRangeBuilder(
-          allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
-        GenericLoopNestRangeBuilder(
-            allIvs.drop_front(nOuterPar),
-            loopRanges.drop_front(nOuterPar))([&] {
-          SmallVector allIvValues(allIvs.begin(), allIvs.end());
-          LinalgScopedEmitter::
-              emitScalarImplementation(allIvValues, linalgOp);
-        });
-      });
-    } else {
-      // If there are no parallel loops then fallback to generating all scf.for
-      // operations.
-      GenericLoopNestRangeBuilder(allIvs, loopRanges)([&] {
-        SmallVector allIvValues(allIvs.begin(), allIvs.end());
-        LinalgScopedEmitter::emitScalarImplementation(allIvValues,
-                                                      linalgOp);
-      });
-    }
-  }
-};
 } // namespace
 
 template
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -383,7 +383,8 @@
     linalgRanges.push_back(
         linalg_range(range.offset, range.size, range.stride));
   }
-  GenericLoopNestRangeBuilder(ivs, linalgRanges)([&] {
+  GenericLoopNestRangeBuilder(ivs, linalgRanges,
+                              op.iterator_types().getValue())([&] {
     auto &b = ScopedContext::getBuilderRef();
     auto loc = ScopedContext::getLocation();
     SmallVector ivValues(ivs.begin(), ivs.end());
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4,8" -split-input-file | FileCheck %s
+
+func @gemm(%arg0 : memref,
+           %arg1 : memref,
+           %arg2 : memref)
+{
+  linalg.matmul(%arg0, %arg1, %arg2)
+    : memref, memref, memref
+  return
+}
+// CHECK-LABEL: func @gemm
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
+// CHECK-SAME: step (%[[C2]], %[[C4]])
+// CHECK: scf.for %[[ARG5:.*]] =
+// CHECK-SAME: step %[[C8]]
+// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
+// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
+// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
+// CHECK: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -44,7 +44,8 @@
 // CHECK-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-DAG: %[[c5:.*]] = constant 5 : index
 // CHECK-DAG: %[[c6:.*]] = constant 6 : index
-// CHECK: scf.parallel {{.*}} step (%[[c5]], %[[c6]])
+// CHECK: scf.parallel {{.*}} step (%[[c5]])
+// CHECK: scf.for {{.*}} step %[[c6]]
 // CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref
 
 func @matmul(%A: memref,