diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -177,18 +177,6 @@
   return spec;
 }
 
-/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
-/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location
-/// as `subsetExtractOp`.
-static void
-createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
-                                     tensor::ExtractSliceOp subsetExtractOp,
-                                     Value source, Value dest) {
-  b.create<tensor::ParallelInsertSliceOp>(
-      loc, source, dest, subsetExtractOp.getMixedOffsets(),
-      subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
-}
-
 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
 /// than `iterationSize`.
 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
@@ -333,16 +321,21 @@
 
   auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
   assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
-
-  auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
-  // Create terminator with parallel subset insert operations.
-  b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
-  for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
-                           destOperands)) {
-    createMatchingParallelSubsetInsertOp(
-        b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
-        std::get<1>(it), std::get<2>(it));
+  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+  for (auto it :
+       llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())),
+                 tilingInterfaceOp->getResults(), destOperands)) {
+    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
+                                        tiledSizes, resultOffsets,
+                                        resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
+    b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
+                                            std::get<2>(it), resultOffsets,
+                                            resultSizes, strides);
   }
   return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -161,15 +161,12 @@
         }));
 
     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
-    Value sliceOpResult =
-        makeTiledShape(b, loc, outOperand->get(), sizes,
-                       linalgOp.getTiedIndexingMap(outOperand), offsets,
-                       /*ubs*/ {}, subShapeSizes, true);
-    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
-    if (!sliceOp)
-      return failure();
-    resultOffsets = sliceOp.getMixedOffsets();
-    resultSizes = sliceOp.getMixedSizes();
+    SliceParameters sliceParams =
+        computeSliceParameters(b, loc, outOperand->get(), sizes,
+                               linalgOp.getTiedIndexingMap(outOperand), offsets,
+                               /*ubs*/ {}, subShapeSizes, true);
+    resultOffsets = sliceParams.offsets;
+    resultSizes = sliceParams.sizes;
     return success();
   }
 
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -59,7 +59,7 @@
 // CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
 // CHECK: scf.yield %[[RESPARTIAL]]
 
-// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1]
+// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
 // CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
 // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
 // CHECK-COUNT-2: tensor.extract_slice
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
 
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -22,7 +22,7 @@
 // CHECK: %[[RES:.*]] = linalg.matmul
 // CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+// CHECK: scf.foreach_thread.perform_concurrently {
 // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
 // CHECK-SAME: tensor<?x?xf32> into tensor<?x?xf32>
 // CHECK-NEXT: }
@@ -65,11 +65,9 @@
 // CHECK-NOT: affine.max
 // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
 // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
-// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
-// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
 // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
 // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
-// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
+// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
 // CHECK: linalg.matmul
 // CHECK: scf.foreach_thread.perform_concurrently
 // CHECK-NEXT: tensor.parallel_insert_slice
@@ -106,8 +104,6 @@
 // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
 // CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
 // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
-// CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
-// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
 // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
 // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
 // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
@@ -115,8 +111,6 @@
 // CHECK tensor.extract_slice %[[A]]
 // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
 // CHECK tensor.extract_slice %[[B]]
-// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
-// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
 // CHECK tensor.extract_slice %[[C]]
 // CHECK: linalg.matmul
 // CHECK: scf.foreach_thread.perform_concurrently
@@ -156,11 +150,9 @@
 // CHECK-NOT: affine.min
 // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
 // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
-// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
-// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
 // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
 // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
-// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
+// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
 // CHECK: linalg.matmul
 // CHECK: scf.foreach_thread.perform_concurrently
 // CHECK-NEXT: tensor.parallel_insert_slice
@@ -177,3 +169,37 @@
     %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
   }
 }
+
+// -----
+
+module {
+  func.func @extract_source(%A: tensor<4xf32>, %B: tensor<16xf32>) -> tensor<4xf32> {
+    %B1 = tensor.extract_slice %B[10] [4] [1] : tensor<16xf32> to tensor<4xf32>
+    %result = linalg.generic {indexing_maps = [
+      affine_map<(d0) -> (d0)>,affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%A : tensor<4xf32>) outs(%B1 : tensor<4xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
+      %2 = arith.addf %arg3, %arg3 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+    return %result : tensor<4xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 failures(propagate) {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+      %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0])
+    }
+  }
+}
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: extract_source(
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) {
+// CHECK:   %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
+// CHECK:   scf.foreach_thread.perform_concurrently {
+// CHECK:     tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>