diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -481,6 +481,75 @@ using TileSizeComputationFunction = std::function(OpBuilder &, Operation *)>; +/// Creates a number of ranges equal to the number of non-zero in `tileSizes`. +/// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument +/// has one entry per surrounding loop. It uses zero as the convention that a +/// particular loop is not tiled. This convention simplifies implementations by +/// avoiding affine map manipulations. +/// The returned ranges correspond to the loop ranges, in the proper order, that +/// are tiled and for which new loops will be created. Also the function returns +/// a map from loop indices of the LinalgOp to the corresponding non-empty range +/// indices of newly created loops. +using LoopIndexToRangeIndexMap = DenseMap; +std::tuple, LoopIndexToRangeIndexMap> +makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, + ValueRange allShapeSizes, ValueRange allTileSizes); + +/// All indices returned by IndexOp should be invariant with respect to tiling. +/// Therefore, if an operation is tiled, we have to transform the indices +/// accordingly, i.e. offset them by the values of the corresponding induction +/// variables that are captured implicitly in the body of the op. +/// +/// Example. 
`linalg.generic` before tiling: +/// +/// #id_2d = (i, j) -> (i, j) +/// #pointwise_2d_trait = { +/// indexing_maps = [#id_2d, #id_2d], +/// iterator_types = ["parallel", "parallel"] +/// } +/// linalg.generic #pointwise_2d_trait %operand, %result { +/// ^bb0(%operand_in: f32, %result_in: f32): +/// %i = linalg.index 0 : index +/// %j = linalg.index 1 : index +/// +/// }: memref<50x100xf32>, memref<50x100xf32> +/// +/// After tiling pass with tile sizes 10 and 25: +/// +/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2) +/// +/// %c1 = arith.constant 1 : index +/// %c0 = arith.constant 0 : index +/// %c25 = arith.constant 25 : index +/// %c10 = arith.constant 10 : index +/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32> +/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32> +/// scf.for %k = %c0 to operand_dim_0 step %c10 { +/// scf.for %l = %c0 to operand_dim_1 step %c25 { +/// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1] +/// : memref<50x100xf32> to memref +/// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1] +/// : memref<50x100xf32> to memref +/// linalg.generic #pointwise_2d_trait %4, %5 { +/// ^bb0(%operand_in: f32, %result_in: f32): +/// %i = linalg.index 0 : index +/// %j = linalg.index 1 : index +/// // Indices `k` and `l` are implicitly captured in the body. +/// // Index `i` is offset by %k and index `j` is offset by %l. +/// %transformed_i = arith.addi %i, %k : index +/// %transformed_j = arith.addi %j, %l : index +/// // Every use of %i, %j is replaced with %transformed_i, %transformed_j +/// +/// }: memref, memref +/// } +/// } +/// +/// TODO: Investigate whether mixing implicit and explicit indices +/// does not lead to losing information. +void transformIndexOps(RewriterBase &b, LinalgOp op, + SmallVectorImpl &ivs, + const LoopIndexToRangeIndexMap &loopIndexToRangeIndex); + /// Callback returning the padding value to use for a given OpOperand or failure /// for no padding. 
This should be a function of both the operation and the /// operand type. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -39,20 +39,10 @@ return false; } -using LoopIndexToRangeIndexMap = DenseMap; - -// Creates a number of ranges equal to the number of non-zero in `tileSizes`. -// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has -// one entry per surrounding loop. It uses zero as the convention that a -// particular loop is not tiled. This convention simplifies implementations by -// avoiding affine map manipulations. -// The returned ranges correspond to the loop ranges, in the proper order, that -// are tiled and for which new loops will be created. Also the function returns -// a map from loop indices of the LinalgOp to the corresponding non-empty range -// indices of newly created loops. -static std::tuple, LoopIndexToRangeIndexMap> -makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, - ValueRange allShapeSizes, ValueRange allTileSizes) { +std::tuple, LoopIndexToRangeIndexMap> +mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, + ValueRange allShapeSizes, + ValueRange allTileSizes) { assert(allTileSizes.size() == map.getNumResults()); // Apply `map` to get shape sizes in loop order. auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes); @@ -78,59 +68,9 @@ return std::make_tuple(res, loopIndexToRangeIndex); } -// All indices returned by IndexOp should be invariant with respect to tiling. -// Therefore, if an operation is tiled, we have to transform the indices -// accordingly, i.e. offset them by the values of the corresponding induction -// variables that are captured implicitly in the body of the op. -// -// Example. 
`linalg.generic` before tiling: -// -// #id_2d = (i, j) -> (i, j) -// #pointwise_2d_trait = { -// indexing_maps = [#id_2d, #id_2d], -// iterator_types = ["parallel", "parallel"] -// } -// linalg.generic #pointwise_2d_trait %operand, %result { -// ^bb0(%operand_in: f32, %result_in: f32): -// %i = linalg.index 0 : index -// %j = linalg.index 1 : index -// -// }: memref<50x100xf32>, memref<50x100xf32> -// -// After tiling pass with tiles sizes 10 and 25: -// -// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2) -// -// %c1 = arith.constant 1 : index -// %c0 = arith.constant 0 : index -// %c25 = arith.constant 25 : index -// %c10 = arith.constant 10 : index -// operand_dim_0 = dim %operand, 0 : memref<50x100xf32> -// operand_dim_1 = dim %operand, 1 : memref<50x100xf32> -// scf.for %k = %c0 to operand_dim_0 step %c10 { -// scf.for %l = %c0 to operand_dim_1 step %c25 { -// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1] -// : memref<50x100xf32> to memref -// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1] -// : memref<50x100xf32> to memref -// linalg.generic pointwise_2d_trait %4, %5 { -// ^bb0(%operand_in: f32, %result_in: f32): -// %i = linalg.index 0 : index -// %j = linalg.index 1 : index -// // Indices `k` and `l` are implicitly captured in the body. -// %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k -// %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l -// // Every use of %i, %j is replaced with %transformed_i, %transformed_j -// -// }: memref, memref -// } -// } -// -// TODO: Investigate whether mixing implicit and explicit indices -// does not lead to losing information. 
-static void -transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl &ivs, - const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { +void mlir::linalg::transformIndexOps( + RewriterBase &b, LinalgOp op, SmallVectorImpl &ivs, + const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { SmallVector allIvs(op.getNumLoops(), nullptr); for (auto &en : enumerate(allIvs)) { auto rangeIndex = loopIndexToRangeIndex.find(en.index());