diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -26,7 +26,7 @@ namespace scf { using SCFTileSizeComputationFunction = - std::function(OpBuilder &, Operation *)>; + std::function(OpBuilder &, Operation *)>; /// Options to use to control tiling. struct SCFTilingOptions { @@ -51,6 +51,13 @@ /// function that computes tile sizes at the point they are needed. Allows /// proper interaction with folding. SCFTilingOptions &setTileSizes(ArrayRef ts); + + /// Permutation of the iteration domain used to reorder the tiled loops. + SmallVector interchangeVector = {}; + SCFTilingOptions &setInterchange(ArrayRef interchange) { + interchangeVector = llvm::to_vector(interchange); + return *this; + } }; struct SCFTilingResult { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -29,7 +29,7 @@ scf::SCFTilingOptions & scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { assert(!tileSizeComputationFunction && "tile sizes already set"); - SmallVector tileSizes(ts.begin(), ts.end()); + SmallVector tileSizes(ts.begin(), ts.end()); tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart( @@ -42,6 +42,49 @@ return *this; } +/// Helper method to adjust the interchange vector to match the iteration +/// domain size, by appending identity entries or truncating extra ones. 
+static SmallVector +fillInterchangeVector(ArrayRef interchangeVector, + size_t iterationDomainSize) { + SmallVector filledVector = llvm::to_vector(interchangeVector); + if (filledVector.size() < iterationDomainSize) { + auto range = llvm::seq(filledVector.size(), iterationDomainSize); + filledVector.append(range.begin(), range.end()); + } + if (filledVector.size() > iterationDomainSize) + filledVector.resize(iterationDomainSize); + return filledVector; +} + +/// Helper method to apply a permutation: `result[i] = vector[interchange[i]]`. +template +static SmallVector applyPermutationToVector(const SmallVector &vector, + ArrayRef interchange) { + assert(interchange.size() == vector.size()); + return llvm::to_vector( + llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); +} +/// Helper method to invert a permutation: `inversion[interchange[i]] = i`. +static SmallVector +invertPermutationVector(ArrayRef interchange) { + SmallVector inversion(interchange.size()); + for (auto pos : llvm::enumerate(interchange)) { + inversion[pos.value()] = pos.index(); + } + return inversion; +} +/// Method to check if an interchange vector is a permutation of [0, size). +static bool isPermutation(ArrayRef interchange) { + llvm::SmallDenseSet seenVals; + for (auto val : interchange) { + // Out-of-range entries would index past the end of the permuted vectors. + if (val >= interchange.size() || !seenVals.insert(val).second) + return false; + } + return true; +} + //===----------------------------------------------------------------------===// // TileUsingSCFForOp pattern implementation. //===----------------------------------------------------------------------===// @@ -137,7 +180,7 @@ // skips tiling a particular dimension. This convention is significantly // simpler to handle instead of adjusting affine maps to account for missing // dimensions. 
- SmallVector tileSizeVector = + SmallVector tileSizeVector = options.tileSizeComputationFunction(rewriter, op); if (tileSizeVector.size() < iterationDomain.size()) { auto zero = rewriter.create(op.getLoc(), 0); @@ -147,12 +190,38 @@ scf::SCFTilingResult tilingResult; SmallVector offsets, sizes; { + // If there is an interchange specified, permute the iteration domain and + // the tile sizes. + SmallVector interchangeVector; + if (!options.interchangeVector.empty()) { + interchangeVector = fillInterchangeVector(options.interchangeVector, + iterationDomain.size()); + } + if (!interchangeVector.empty()) { + if (!isPermutation(interchangeVector)) { + return rewriter.notifyMatchFailure( + op, "invalid interchange vector, not a permutation of the entire " + "iteration space"); + } + + iterationDomain = + applyPermutationToVector(iterationDomain, interchangeVector); + tileSizeVector = + applyPermutationToVector(tileSizeVector, interchangeVector); + } + // 3. Materialize an empty loop nest that iterates over the tiles. These // loops for now do not return any values even if the original operation has // results. 
tilingResult.loops = generateTileLoopNest( rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); + if (!interchangeVector.empty()) { + auto inversePermutation = invertPermutationVector(interchangeVector); + offsets = applyPermutationToVector(offsets, inversePermutation); + sizes = applyPermutationToVector(sizes, inversePermutation); + } + LLVM_DEBUG({ if (!tilingResult.loops.empty()) { llvm::errs() << "LoopNest shell :\n"; diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -183,3 +183,50 @@ // CHECK-SAME: outs(%[[OUTS_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] // CHECK scf.yield %[[INSERT]] + +// ----- + +func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %cst = arith.constant 0.0 : f32 + %0 = linalg.init_tensor [%d0, %d1] : tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %2 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%1 : tensor) -> tensor + %3 = linalg.generic { + __internal_linalg_transform__ = "gemm_interchange_fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%2 : tensor) outs(%0 : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %4 = arith.addf %b0, %b0 : f32 + linalg.yield %4 : f32 + } -> tensor + return %3 : tensor +} +// CHECK: func.func @interchange_matmul_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// 
CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) +// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]] +// CHECK: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK: %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GEMM_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE_2]] : +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK: scf.yield %[[INSERT]] diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -226,3 +226,52 @@ } -> (tensor) return %0 : tensor } + +// ----- + +func.func @interchange_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_linalg_transform__ = "gemm_interchange"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func.func @interchange_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: 
%[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) +// CHECK: %[[TS_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C20]], %[[N]]] +// CHECK: %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK: %[[TS_K:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C30]], %[[K]]] +// CHECK: %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] +// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]]) +// CHECK-DAG: %[[TS_M:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C10]], %[[M]]] +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_M]], %[[TS_K]]] [1, 1] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [%[[IV1]], %[[IV0]]] [%[[TS_K]], %[[TS_N]]] [1, 1] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]] +// CHECK-SAME: [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT2]] +// CHECK-SAME: [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[INNER2]] +// CHECK: scf.yield %[[INNER1]] +// CHECK: return %[[OUTER]] diff --git 
a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -147,10 +147,11 @@ template static void -addPatternForTiling(MLIRContext *context, ArrayRef tileSizes, - StringRef filterName, RewritePatternSet &patterns) { +addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns, + StringRef filterName, ArrayRef tileSizes, + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes); + tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); linalg::LinalgTransformationFilter filter( StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add(context, tilingOptions, filter); @@ -161,29 +162,35 @@ if (testTiling) { // 1. Tiling M and N dims of `linalg.matmul` on tensors. addPatternForTiling( - context, {10, 20}, "simple_gemm", patterns); + context, patterns, "simple_gemm", {10, 20}); // 2. Tiling M, N and K of `linalg.matmul` on buffers. addPatternForTiling( - context, {10, 20, 30}, "simple_gemm_memref", patterns); + context, patterns, "simple_gemm_memref", {10, 20, 30}); // 3. Tiling 3D parallel generic op which implements a transpose addPatternForTiling( - context, {10, 0, 20}, "parallel_generic_transpose", patterns); + context, patterns, "parallel_generic_transpose", {10, 0, 20}); // 4. Tiling 2D conv op. addPatternForTiling( - context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns); + context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30}); // 5. Tiling a simple op with `linalg.index` inside. addPatternForTiling( - context, {10, 20}, "indexed_semantics", patterns); + context, patterns, "indexed_semantics", {10, 20}); + // 6. 
Tiling of `linalg.matmul` with loops interchanged to N, K, M order. + addPatternForTiling( + context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0}); return; } if (testTileConsumerAndFuseProducer) { // 1. Tile and fuse of gemm with bias-add operation. addPatternForTiling< TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, {10, 20}, "fusion", patterns); + context, patterns, "fusion", {10, 20}); + addPatternForTiling< + TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( + context, patterns, "gemm_fusion", {10}); addPatternForTiling< TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, {10}, "gemm_fusion", patterns); + context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); return; } }