diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -243,10 +243,11 @@ ArrayRef<Value> sizeBounds, bool omitPartialTileCheck); -/// Add the tile loop induction variables `ivs` to the IndexOp results found in -/// the body of the `tiledOp` to account for the tile offset. -void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp, - ArrayRef<Value> ivs); +/// Add the specified offsets to any `linalg.index` ops contained in the given +/// `linalgOp`. The offsets are provided in the same order as iteration space +/// dimensions. Null offsets are assumed to be zero. +void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets); +void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offsets); using FusableOpDependencesTy = llvm::MapVector< Operation *, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -170,7 +170,7 @@ SmallVector<Value> allIvs; llvm::transform(loopRanges, std::back_inserter(allIvs), [](Range range) { return range.offset; }); - addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); + offsetIndices(b, clonedOp, allIvs); return clonedOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -186,7 +186,7 @@ LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); // Shift all IndexOp results by the tile offset. 
- addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); + offsetIndices(b, clonedOp, allIvs); return clonedOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -139,8 +139,7 @@ SmallVector<Value> ivAdditions; ivAdditions.resize(splitIterationSpace.size()); ivAdditions[dimension] = splitPointValue; - linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second), - ivAdditions); + linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions); // Replace the original op with the results of the two newly created ops. rewriter.replaceOp(op, secondResults); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -80,7 +80,7 @@ continue; en.value() = ivs[rangeIndex->second]; } - addTileLoopIvsToIndexOpResults(b, op, allIvs); + offsetIndices(b, op, allIvs); } /// Asserts that the given index-typed value is strictly positive. 
If the value diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -71,9 +71,10 @@ Location loc = op->getLoc(); LinalgOp linalgOp = cast<LinalgOp>(op); SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector<Value> offsetValues = + getValueOrCreateConstantIndexOp(b, loc, offsets); SmallVector<Value> tiledOperands = makeTiledShapes( - b, loc, linalgOp, valuesToTile, - getValueOrCreateConstantIndexOp(b, loc, offsets), + b, loc, linalgOp, valuesToTile, offsetValues, getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range( @@ -83,6 +84,7 @@ Operation *tiledOp = linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues); return {tiledOp}; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -1048,21 +1048,29 @@ return tiledShapes; } -void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp, - ArrayRef<Value> ivs) { - if (tiledOp.hasIndexSemantics()) { - for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) { - if (ivs[indexOp.dim()] == nullptr) - continue; - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointAfter(indexOp); - AffineExpr index, offset; - bindDims(b.getContext(), index, offset); - AffineApplyOp applyOp = makeComposedAffineApply( - b, indexOp.getLoc(), index + offset, - ValueRange{indexOp.getResult(), ivs[indexOp.dim()]}); - indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); - } +void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets) { + IRRewriter rewriter(b); + offsetIndices(rewriter, linalgOp, offsets); +} + +void offsetIndices(RewriterBase &b, LinalgOp linalgOp, + 
ArrayRef<Value> offsets) { + if (!linalgOp.hasIndexSemantics()) + return; + + for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) { + if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr) + continue; + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(indexOp); + AffineExpr index, offset; + bindDims(b.getContext(), index, offset); + AffineApplyOp applyOp = makeComposedAffineApply( + b, indexOp.getLoc(), index + offset, + ValueRange{indexOp.getResult(), offsets[indexOp.dim()]}); + b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != applyOp; + }); } } diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -192,3 +192,37 @@ // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]] // CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] + +// ----- + +// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)> + +// CHECK-LABEL: @indexed_semantics +func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + // Check that we correctly amend "linalg.index" results. 
+ + // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}} + // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + {__internal_linalg_transform__ = "indexed_semantics"} + ins(%arg0: tensor<?x?xf32>) + outs(%arg1: tensor<?x?xf32>) { + ^bb0(%arg2: f32, %arg3: f32): + // CHECK: %[[INDEX0:.+]] = linalg.index 0 + // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]]) + %1 = linalg.index 0 : index + // CHECK: %[[INDEX1:.+]] = linalg.index 1 + // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]]) + %2 = linalg.index 1 : index + // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]] + %3 = arith.addi %1, %2 : index + %4 = arith.index_cast %3 : index to i64 + %5 = arith.uitofp %4 : i64 to f32 + %6 = arith.addf %5, %arg2 : f32 + linalg.yield %6 : f32 + } -> (tensor<?x?xf32>) + return %0 : tensor<?x?xf32> +} diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -171,6 +171,9 @@ // 4. Tiling 2D conv op. addPatternForTiling( context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns); + // 5. Tiling a simple op with `linalg.index` inside. + addPatternForTiling( + context, {10, 20}, "indexed_semantics", patterns); return; } if (testTileConsumerAndFuseProducer) {