diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -172,6 +172,82 @@
   }
 }
 
+// All indices returned by IndexOp should be invariant with respect to tiling.
+// Therefore, if an operation is tiled, we have to transform the indices
+// accordingly, i.e. offset them by the values of the corresponding induction
+// variables that are captured implicitly in the body of the op.
+//
+// Example. `linalg.generic` before tiling:
+//
+// #id_2d = (i, j) -> (i, j)
+// #pointwise_2d_trait = {
+//   indexing_maps = [#id_2d, #id_2d],
+//   iterator_types = ["parallel", "parallel"]
+// }
+// linalg.generic #pointwise_2d_trait %operand, %result {
+//   ^bb0(%operand_in: f32, %result_in: f32):
+//     %i = linalg.index 0 : index
+//     %j = linalg.index 1 : index
+//     <some operations that use %i, %j>
+// }: memref<50x100xf32>, memref<50x100xf32>
+//
+// After tiling pass with tile sizes 10 and 25:
+//
+// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+//
+// %c1 = constant 1 : index
+// %c0 = constant 0 : index
+// %c25 = constant 25 : index
+// %c10 = constant 10 : index
+// %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+// %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+// scf.for %k = %c0 to %operand_dim_0 step %c10 {
+//   scf.for %l = %c0 to %operand_dim_1 step %c25 {
+//     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     linalg.generic #pointwise_2d_trait %4, %5 {
+//     ^bb0(%operand_in: f32, %result_in: f32):
+//       %i = linalg.index 0 : index
+//       %j = linalg.index 1 : index
+//       // Indices `k` and `l` are implicitly captured in the body.
+//       %transformed_i = addi %i, %k : index // index `i` is offset by %k
+//       %transformed_j = addi %j, %l : index // index `j` is offset by %l
+//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
+//       <some operations that use %transformed_i, %transformed_j>
+//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+//   }
+// }
+//
+// TODO: Investigate whether mixing implicit and explicit indices
+// does not lead to losing information.
+static void
+transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+  // Skip operations that have no region attached.
+  if (op->getNumRegions() == 0)
+    return;
+  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+         "expected linalg operation to have one block.");
+  Block &block = op->getRegion(0).front();
+  OpBuilder::InsertionGuard g(b);
+
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
+    if (rangeIndex == loopIndexToRangeIndex.end())
+      continue;
+    // Offset the index by the value of the corresponding induction variable and
+    // replace all uses of the previous value.
+    b.setInsertionPointAfter(indexOp);
+    AddIOp addOp = b.create<AddIOp>(indexOp.getLoc(), indexOp.getResult(),
+                                    ivs[rangeIndex->second]);
+    indexOp.getResult().replaceAllUsesExcept(
+        addOp.getResult(), SmallPtrSet<Operation *, 1>{addOp});
+  }
+}
+
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -299,8 +375,10 @@
       },
       options.distribution);
 
-  // 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+  // 3a. Transform `linalg.indexed_generic` index arguments w.r.t. the tiling.
   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
+  // 3b. Transform IndexOp results w.r.t. the tiling.
+  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
 
   // 4. Gather the newly created loops and return them with the new op.
   SmallVector<Operation *, 8> loops;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,8 +246,7 @@
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/tile-indexed.mlir b/mlir/test/Dialect/Linalg/tile-indexed.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-indexed.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" | FileCheck %s -check-prefix=TILE-10n25
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" | FileCheck %s -check-prefix=TILE-25n0
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" | FileCheck %s -check-prefix=TILE-0n25
+
+func @indexed_vector(%arg0: memref<50xindex>) {
+  linalg.generic {indexing_maps = [affine_map<(i) -> (i)>],
+                  iterator_types = ["parallel"]}
+                  outs(%arg0 : memref<50xindex>) {
+    ^bb0(%a: index):
+      %i = linalg.index 0 : index
+      linalg.yield %i : index
+  }
+  return
+}
+// TILE-10n25-LABEL: func @indexed_vector
+// TILE-10n25: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   linalg.generic
+// TILE-10n25:     %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25:     %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
+// TILE-10n25:     linalg.yield %[[NEW_I]] : index
+
+// TILE-25n0-LABEL: func @indexed_vector
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.generic
+// TILE-25n0:     %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0:     %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
+// TILE-25n0:     linalg.yield %[[NEW_I]] : index
+
+// TILE-0n25-LABEL: func @indexed_vector
+// TILE-0n25-NOT: scf.for
+// TILE-0n25: linalg.generic
+
+func @indexed_matrix(%arg0: memref<50x50xindex>) {
+  linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>],
+                  iterator_types = ["parallel", "parallel"]}
+                  outs(%arg0 : memref<50x50xindex>) {
+    ^bb0(%a: index):
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %sum = addi %i, %j : index
+      linalg.yield %sum : index
+  }
+  return
+}
+// TILE-10n25-LABEL: func @indexed_matrix
+// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
+// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-10n25:     linalg.generic
+// TILE-10n25:       %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25:       %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index
+// TILE-10n25:       %[[J:.*]] = linalg.index 1 : index
+// TILE-10n25:       %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
+// TILE-10n25:       %[[SUM:.*]] = addi %[[NEW_I]], %[[NEW_J]] : index
+// TILE-10n25:       linalg.yield %[[SUM]] : index
+
+// TILE-25n0-LABEL: func @indexed_matrix
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.generic
+// TILE-25n0:     %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0:     %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index
+// TILE-25n0:     %[[J:.*]] = linalg.index 1 : index
+// TILE-25n0:     %[[SUM:.*]] = addi %[[NEW_I]], %[[J]] : index
+// TILE-25n0:     linalg.yield %[[SUM]] : index
+
+// TILE-0n25-LABEL: func @indexed_matrix
+// TILE-0n25: %[[C25:.*]] = constant 25 : index
+// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-0n25:   linalg.generic
+// TILE-0n25:     %[[I:.*]] = linalg.index 0 : index
+// TILE-0n25:     %[[J:.*]] = linalg.index 1 : index
+// TILE-0n25:     %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
+// TILE-0n25:     %[[SUM:.*]] = addi %[[I]], %[[NEW_J]] : index
+// TILE-0n25:     linalg.yield %[[SUM]] : index
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,18 +377,3 @@
 //       TILE-234:   for
 //   TILE-234-NOT:   for
 //       TILE-234:     linalg.generic
-
-// TILE-2-LABEL: func @index_op
-//   TILE-2-NOT:   for
-//       TILE-2:   linalg.generic
-func @index_op(%arg0: memref<?x?xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<?x?xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}