diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -9,14 +9,21 @@
 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
 #define MLIR_DIALECT_LINALG_UTILS_H_
 
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 #include "llvm/ADT/SetVector.h"
 
+using mlir::edsc::intrinsics::AffineIndexedValue;
+using mlir::edsc::intrinsics::StdIndexedValue;
+
 namespace mlir {
 class AffineExpr;
+class AffineForOp;
 class AffineMap;
 class OperationFolder;
 class PatternRewriter;
@@ -49,6 +56,15 @@
   static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
 };
 
+/// Checks if an iterator_type attribute is parallel.
+bool isParallelIteratorType(Attribute attr);
+
+/// Checks if an iterator_type attribute is reduction.
+bool isReductionIteratorType(Attribute attr);
+
+/// Checks if an iterator_type attribute is window.
+bool isWindowIteratorType(Attribute attr);
+
 /// Checks whether the specific `producer` is the last write to exactly the
 /// whole `consumedView`. This checks structural dominance, that the dependence
 /// is a RAW without any interleaved write to any piece of `consumedView`.
@@ -141,6 +157,21 @@
   inVec = auxVec;
 }
 
+/// Utility class used to generate nested loops of type `LoopTy` with ranges
+/// described by `loopRanges` and loop kinds chosen according to
+/// `iteratorTypes`. `allIvs` is populated with the induction variables of all
+/// generated loops on return; `fun` generates the body of the innermost loop.
+template <typename LoopTy>
+struct GenerateLoopNest {
+  using IndexedValueTy =
+      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
+                                AffineIndexedValue, StdIndexedValue>::type;
+  static void doit(MutableArrayRef<Value> allIvs,
+                   ArrayRef<SubViewOp::Range> loopRanges,
+                   ArrayRef<Attribute> iteratorTypes,
+                   std::function<void(void)> fun);
+};
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -487,80 +487,9 @@
   }
 };
 
-namespace {
-/// Helper struct to generate the loop nest for the op. This factored out here
-/// to be able to partially specialize this for different LoopTy.
-template <typename LoopTy, typename ConcreteOpTy>
-class GenerateLoopNest {
-public:
-  using IndexedValueTy =
-      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
-                                AffineIndexedValue, StdIndexedValue>::type;
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
-      SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-      LinalgScopedEmitter<IndexedValueTy,
-                          ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                  linalgOp);
-    });
-  }
-};
-
-/// Generates loop nest using scf.parallel. scf.parallel is only used for the
-/// outer parallel loops. All other loops are generated using scf.for
-/// operation.
-template <typename ConcreteOpTy>
-class GenerateLoopNest<scf::ParallelOp, ConcreteOpTy> {
-public:
-  using IndexedValueTy = StdIndexedValue;
-
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    // Only generate scf.parallel for outer consecutive "parallel"
-    // iterator_types.
-    // TODO(ravishankarm): Generate scf.parallel for all "parallel" iterator
-    // types, not just the outer most ones. Also handle "reduction" iterator
-    // types.
-    auto nOuterPar = linalgOp.iterator_types()
-                         .getValue()
-                         .take_while([](Attribute attr) {
-                           return attr.cast<StringAttr>().getValue() ==
-                                  getParallelIteratorTypeName();
-                         })
-                         .size();
-    // If there are no outer parallel loops, then number of loop ops is same as
-    // the number of loops, and they are all scf.for ops.
-    if (nOuterPar) {
-      GenericLoopNestRangeBuilder<scf::ParallelOp>(
-          allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
-        GenericLoopNestRangeBuilder<scf::ForOp>(
-            allIvs.drop_front(nOuterPar),
-            loopRanges.drop_front(nOuterPar))([&] {
-          SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-          LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
-              emitScalarImplementation(allIvValues, linalgOp);
-        });
-      });
-    } else {
-      // If there are no parallel loops then fallback to generating all scf.for
-      // operations.
-      GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)([&] {
-        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-        LinalgScopedEmitter<StdIndexedValue,
-                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                    linalgOp);
-      });
-    }
-  }
-};
-} // namespace
-
 template <typename LoopTy, typename ConcreteOpTy>
 Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
-  using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
-  using IndexedValueTy =
-      typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
+  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
@@ -591,7 +520,13 @@
       emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                      getViewSizes(builder, linalgOp));
   assert(loopRanges.size() == allIvs.size());
-  Impl::doit(linalgOp, loopRanges, allIvs);
+  GenerateLoopNest<LoopTy>::doit(
+      allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] {
+        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+        LinalgScopedEmitter<IndexedValueTy,
+                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                    linalgOp);
+      });
   // Number of loop ops might be different from the number of ivs since some
   // loops like affine.parallel and scf.parallel have multiple ivs.
   llvm::SetVector<Operation *> loopSet;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -376,7 +376,11 @@
   // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs(loopRanges.size());
-  GenericLoopNestRangeBuilder<LoopTy>(ivs, loopRanges)([&] {
+  SmallVector<Attribute, 4> iteratorTypes =
+      llvm::to_vector<4>(op.iterator_types().cast<ArrayAttr>().getValue());
+  if (!options.interchangeVector.empty())
+    applyPermutationToVector(iteratorTypes, options.interchangeVector);
+  GenerateLoopNest<LoopTy>::doit(ivs, loopRanges, iteratorTypes, [&] {
     auto &b = ScopedContext::getBuilderRef();
     auto loc = ScopedContext::getLocation();
     SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
@@ -384,8 +388,8 @@
     // If we have to apply a permutation to the tiled loop nest, we have to
     // reorder the induction variables This permutation is the right one
     // assuming that loopRanges have previously been permuted by
-    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
-    // that one: (d0,d1,d2)->(d2,d0,d1)
+    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
+    // of that one: (d0,d1,d2)->(d2,d0,d1)
     if (!options.interchangeVector.empty())
       ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
@@ -101,3 +102,91 @@
   }
   return res;
 }
+
+bool mlir::linalg::isParallelIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getParallelIteratorTypeName();
+  }
+  return false;
+}
+
+bool mlir::linalg::isReductionIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getReductionIteratorTypeName();
+  }
+  return false;
+}
+
+bool mlir::linalg::isWindowIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getWindowIteratorTypeName();
+  }
+  return false;
+}
+
+/// Explicit instantiation of loop nest generator for different loop types.
+template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
+template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
+template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
+
+/// Specialization of loop nest generator for scf.parallel loops to handle
+/// iterator types that are not parallel. These are generated as sequential
+/// loops.
+template <>
+void mlir::linalg::GenerateLoopNest<scf::ForOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  edsc::GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)(fun);
+}
+
+template <>
+void mlir::linalg::GenerateLoopNest<AffineForOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  edsc::GenericLoopNestRangeBuilder<AffineForOp>(allIvs, loopRanges)(fun);
+}
+
+template <>
+void mlir::linalg::GenerateLoopNest<scf::ParallelOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  // Check if there is nothing to do here. This is also the recursion
+  // termination.
+  if (loopRanges.empty())
+    return;
+  size_t nOuterPar = iteratorTypes.take_front(loopRanges.size())
+                         .take_while(isParallelIteratorType)
+                         .size();
+  if (nOuterPar == 0 && loopRanges.size() == 1)
+    // Generate the sequential for loop for the remaining non-parallel loop.
+    return GenerateLoopNest<scf::ForOp>::doit(allIvs, loopRanges, iteratorTypes,
+                                              fun);
+  if (nOuterPar == 0) {
+    // The immediate outer loop is not parallel. Generate a scf.for op for this
+    // loop, but there might be subsequent loops that are parallel. Use
+    // recursion to find those.
+    auto nestedFn = [&]() {
+      GenerateLoopNest<scf::ParallelOp>::doit(allIvs.drop_front(),
+                                              loopRanges.drop_front(),
+                                              iteratorTypes.drop_front(), fun);
+    };
+    return GenerateLoopNest<scf::ForOp>::doit(allIvs[0], loopRanges[0],
+                                              iteratorTypes[0], nestedFn);
+  }
+  if (nOuterPar == loopRanges.size()) {
+    // All loops are parallel, so generate the scf.parallel op.
+    return edsc::GenericLoopNestRangeBuilder<scf::ParallelOp>(allIvs,
+                                                              loopRanges)(fun);
+  }
+  // Generate scf.parallel for the outer parallel loops. The next inner loop is
+  // sequential, but there might be more parallel loops after that. So recurse
+  // into the same method.
+  auto nestedFn = [&]() {
+    GenerateLoopNest<scf::ParallelOp>::doit(
+        allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar),
+        iteratorTypes.drop_front(nOuterPar), fun);
+  };
+  return GenerateLoopNest<scf::ParallelOp>::doit(
+      allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar),
+      iteratorTypes.take_front(nOuterPar), nestedFn);
+}
diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir
--- a/mlir/test/Dialect/Linalg/parallel_loops.mlir
+++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir
@@ -57,6 +57,42 @@
 // CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
 // CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
 // CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
-// CHECK: scf.for %[[IV3:.*]] = %[[C0]] to %[[D3]] step %[[C1]]
+// CHECK: scf.parallel (%[[IV3:.*]]) = (%[[C0]]) to (%[[D3]]) step (%[[C1]])
 // CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 // CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV3]]]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
+  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
+]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
+  indexing_maps = #accesses
+}
+
+func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>) {
+  linalg.generic #trait %A, %B {
+    ^bb0(%a: f32, %b: f32):
+      linalg.yield %a: f32
+  } : memref<?x?x?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: @lower_mixed_parallel
+// CHECK-DAG: %[[C0:.*]] = constant 0
+// CHECK-DAG: %[[C1:.*]] = constant 1
+// CHECK-DAG: %[[D0:.*]] = dim %{{.*}}, 0
+// CHECK-DAG: %[[D1:.*]] = dim %{{.*}}, 1
+// CHECK-DAG: %[[D2:.*]] = dim %{{.*}}, 2
+// CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
+// CHECK-DAG: %[[D4:.*]] = dim %{{.*}}, 4
+// CHECK-DAG: %[[D5:.*]] = dim %{{.*}}, 5
+// CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
+// CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
+// CHECK: scf.parallel (%[[IV3:.*]], %[[IV4:.*]]) = (%[[C0]], %[[C0]]) to (%[[D3]], %[[D4]]) step (%[[C1]], %[[C1]])
+// CHECK: scf.for %[[IV5:.*]] = %[[C0]] to %[[D5]] step %[[C1]]
+// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]]
+// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV2]], %[[IV4]], %[[IV5]]]
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4,8" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2" -split-input-file | FileCheck %s -check-prefix=TILE1
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4" -split-input-file | FileCheck %s -check-prefix=TILE2
+
+func @gemm(%arg0 : memref<?x?xf32>,
+           %arg1 : memref<?x?xf32>,
+           %arg2 : memref<?x?xf32>)
+{
+  linalg.matmul(%arg0, %arg1, %arg2)
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @gemm
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
+// CHECK-SAME: step (%[[C2]], %[[C4]])
+// CHECK: scf.for %[[ARG5:.*]] =
+// CHECK-SAME: step %[[C8]]
+// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
+// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
+// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
+// CHECK: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+
+// TILE1-LABEL: func @gemm
+// TILE1-DAG: %[[C2:.*]] = constant 2 : index
+// TILE1: scf.parallel (%[[ARG3:.*]]) =
+// TILE1-SAME: step (%[[C2]])
+// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+// TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+// TILE1-NOT: subview
+// TILE1: linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]])
+
+// TILE2-LABEL: func @gemm
+// TILE2-DAG: %[[C2:.*]] = constant 2 : index
+// TILE2-DAG: %[[C4:.*]] = constant 4 : index
+// TILE2: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
+// TILE2-SAME: step (%[[C2]], %[[C4]])
+// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+// TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
+// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
+// TILE2: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+#accesses = [#map0, #map1, #map2]
+#trait = {
+  args_in = 2 : i64,
+  args_out = 1 : i64,
+  iterator_types = ["reduction", "parallel", "reduction"],
+  indexing_maps = #accesses
+}
+
+func @reduction(%arg0 : memref<?x?x?xf32>,
+                %arg1 : memref<?x?xf32>,
+                %arg2 : memref<?xf32>)
+{
+  linalg.generic #trait %arg0, %arg1, %arg2 {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %0 = addf %arg3, %arg4 : f32
+    %1 = addf %0, %arg5 : f32
+    linalg.yield %1 : f32
+  } : memref<?x?x?xf32>, memref<?x?xf32>, memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @reduction
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK-SAME: step %[[C2]]
+// CHECK: scf.parallel (%[[ARG4:.*]]) =
+// CHECK-SAME: step (%[[C4]])
+// CHECK: scf.for %[[ARG5:.*]] =
+// CHECK-SAME: step %[[C8]]
+// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], %[[ARG5]]]
+// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
+// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
+// CHECK: linalg.generic
+// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
+
+// TILE1-LABEL: func @reduction
+// TILE1-DAG: %[[C2:.*]] = constant 2 : index
+// TILE1: scf.for %[[ARG3:.*]] =
+// TILE1-SAME: step %[[C2]]
+// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0, 0]
+// TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+// TILE1-NOT: subview
+// TILE1: linalg.generic
+// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}}
+
+// TILE2-LABEL: func @reduction
+// TILE2-DAG: %[[C2:.*]] = constant 2 : index
+// TILE2-DAG: %[[C4:.*]] = constant 4 : index
+// TILE2: scf.for %[[ARG3:.*]] =
+// TILE2-SAME: step %[[C2]]
+// TILE2: scf.parallel (%[[ARG4:.*]]) =
+// TILE2-SAME: step (%[[C4]])
+// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], 0]
+// TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
+// TILE2: linalg.generic
+// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -44,7 +44,8 @@
 // CHECK-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK-DAG: %[[c5:.*]] = constant 5 : index
 // CHECK-DAG: %[[c6:.*]] = constant 6 : index
-// CHECK: scf.parallel {{.*}} step (%[[c5]], %[[c6]])
+// CHECK: scf.parallel {{.*}} step (%[[c5]])
+// CHECK: scf.for {{.*}} step %[[c6]]
 // CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
 
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -364,3 +365,25 @@
 // CHECK: linalg.fill(%[[v0]], {{%.*}}) : memref, f32
 // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref
 // CHECK: linalg.fill(%[[v0]], %[[cf]]) : memref, f32
+
+func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
+                                 %arg1: memref<?x?xf32>,
+                                 %arg2: memref<?x?xf32>) {
+  linalg.matmul(%arg0, %arg1, %arg2) {__internal_linalg_transform__ = "par__with_perm__"}
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @tile_permute_parallel_loop
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[C16:.*]] = constant 16 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[D0:.*]] = dim %[[ARG0]], 0
+// CHECK-DAG: %[[D1:.*]] = dim %[[ARG0]], 1
+// CHECK-DAG: %[[D2:.*]] = dim %[[ARG1]], 1
+// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D2]]) step (%[[C8]])
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[D1]] step %[[C4]]
+// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D0]]) step (%[[C16]])
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -101,6 +101,14 @@
       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
 
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      LinalgTilingOptions()
+          .setTileSizes({16, 8, 4})
+          .setInterchange({1, 2, 0})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
+
   //===--------------------------------------------------------------------===//
   // Linalg to loops patterns.
   //===--------------------------------------------------------------------===//
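
For illustration only, a minimal standalone sketch (not part of the patch) of the band-partitioning logic implemented by the GenerateLoopNest<scf::ParallelOp> specialization in Utils.cpp: take the maximal prefix of "parallel" iterators as one scf.parallel band, emit an scf.for for the next non-parallel iterator, and recurse on the remainder. Helper names such as printLoopBands are hypothetical and no MLIR APIs are used; only the recursion structure is mirrored.

// Minimal sketch, assuming only that the recursion above splits iterator_types
// into alternating scf.parallel / scf.for bands.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Emit one "scf.parallel" band per maximal prefix of "parallel" iterators,
// one "scf.for" for the next non-parallel iterator, then recurse on the rest.
static void printLoopBands(const std::vector<std::string> &iters,
                           std::size_t begin = 0) {
  if (begin == iters.size())
    return;
  std::size_t end = begin;
  while (end < iters.size() && iters[end] == "parallel")
    ++end;
  if (end == begin) {
    std::cout << "scf.for      (1 loop)\n";
    printLoopBands(iters, begin + 1);
    return;
  }
  std::cout << "scf.parallel (" << (end - begin) << " loops)\n";
  printLoopBands(iters, end);
}

int main() {
  // Same iterator_types as the @lower_mixed_parallel test above.
  std::vector<std::string> iters = {"parallel", "parallel", "reduction",
                                    "parallel", "parallel", "reduction"};
  printLoopBands(iters);
  // Prints: scf.parallel (2 loops), scf.for (1 loop), scf.parallel (2 loops),
  // scf.for (1 loop) -- the same nesting the new CHECK lines verify.
  return 0;
}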