diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -30,9 +30,13 @@ /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) /// step (%arg4*tileSize[0], /// %arg5*tileSize[1]) -/// scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%j0) -/// min(tileSize[1], %arg3-%j1)) +/// scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0) +/// min(tileSize[1], %arg3-%i1)) /// step (%arg4, %arg5) +/// +/// where the uses of %i0 and %i1 in the loop body are replaced by +/// %i0 + %j0 and %i1 + %j1. +/// /// The old loop is replaced with the new one. void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) { OpBuilder b(op); @@ -85,6 +89,18 @@ // Steal the body of the old parallel loop and erase it. innerLoop.region().takeBody(op.region()); + + // Insert computation for new index vectors and replace uses. 
+ b.setInsertionPointToStart(innerLoop.getBody()); + for (auto ivs : + llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) { + Value innerIndex = std::get<0>(ivs); + AddIOp newIndex = + b.create<AddIOp>(op.getLoc(), innerIndex, std::get<1>(ivs)); + innerIndex.replaceAllUsesExcept( + newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()}); + } + op.erase(); } diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir @@ -25,10 +25,12 @@ // CHECK: [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]]) // CHECK: [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]]) // CHECK: scf.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) { -// CHECK: [[VAL_21:%.*]] = load [[VAL_7]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref -// CHECK: [[VAL_22:%.*]] = load [[VAL_8]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref -// CHECK: [[VAL_23:%.*]] = addf [[VAL_21]], [[VAL_22]] : f32 -// CHECK: store [[VAL_23]], [[VAL_9]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref +// CHECK: [[VAL_21:%.*]] = addi [[VAL_19]], [[VAL_15]] : index +// CHECK: [[VAL_22:%.*]] = addi [[VAL_20]], [[VAL_16]] : index +// CHECK: [[VAL_23:%.*]] = load [[VAL_7]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref +// CHECK: [[VAL_24:%.*]] = load [[VAL_8]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref +// CHECK: [[VAL_25:%.*]] = addf [[VAL_23]], [[VAL_24]] : f32 +// CHECK: store [[VAL_25]], [[VAL_9]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref // CHECK: } // CHECK: } // CHECK: return