diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -30,58 +30,26 @@
 // StructuredOp specific helpers.
 //===----------------------------------------------------------------------===//
 
-/// Relate the producer to the consumer loop iterations that access the same
-/// producer result element:
-///  consumerToProducerLoops =
-///    inverse(producerIndexingMap).compose(consumerIndexingMap).
-/// Return `consumerToProducerLoops` or none if the inversion fails.
-static Optional<AffineMap>
-getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
-                              AffineMap consumerIndexingMap) {
-  assert(consumerIndexingMap.getNumResults() ==
-             producerIndexingMap.getNumResults() &&
-         "expect the number of indexing map results to match");
-  // Ensure the producer indexing map is a projected permutation.
-  if (!producerIndexingMap.isProjectedPermutation())
-    return None;
-  AffineMap inverseIndexingMap =
-      inverseAndBroadcastProjectedPermuation(producerIndexingMap);
-  return inverseIndexingMap.compose(consumerIndexingMap);
-}
-
-/// Returns the producer result slice dimensions tiled by the tile loop nest or
-/// an empty vector if `getConsumerToProducerLoopsMap` returns none.
-// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
-SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
-                                       OpOperand *consumerOperand,
-                                       ArrayRef<int64_t> tiledLoopDims) {
+/// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
+/// The slice defines a hyper rectangular iteration space and fusing the
+/// producer is always possible. However, depending on the consumer indexing
+/// map, not all slice elements may be consumed and the tiles may overlap. In
+/// these cases, fusion introduces redundant computation. The returned slice
+/// dimensions are sorted in increasing order and contain no duplicates.
+static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
+                                              ArrayRef<int64_t> tiledLoopDims) {
+  // Get the consumer operand indexing map.
   LinalgOp consumerOp = consumerOperand->getOwner();
-  LinalgOp producerOp = producerResult.getOwner();
-  OpOperand *opOperand =
-      producerOp.getOutputOperand(producerResult.getResultNumber());
-
-  // Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
-  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
-  Optional<AffineMap> consumerToProducerLoopsMap =
-      getConsumerToProducerLoopsMap(
-          producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
-  if (!consumerToProducerLoopsMap.hasValue())
-    return {};
-
-  // Compute the set of tiled producer loops.
-  DenseSet<int64_t> tiledProducerLoops;
-  for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
-    for (int64_t dim : tiledLoopDims) {
-      if (en.value().isFunctionOfDim(dim))
-        tiledProducerLoops.insert(en.index());
-    }
-  }
-
-  // Compute the slice dimensions for the tiled producer loops.
+  AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
+
+  // Search the slice dimensions tiled by a tile loop dimension. Visiting the
+  // indexing map results in order keeps the slice dimensions sorted and free
+  // of duplicates.
   SmallVector<int64_t> tiledSliceDims;
-  for (auto en : enumerate(producerIndexingMap.getResults())) {
-    auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
-    if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
+  for (auto en : enumerate(indexingMap.getResults())) {
+    if (llvm::any_of(tiledLoopDims, [&](int64_t tiledLoopDim) {
+          return en.value().isFunctionOfDim(tiledLoopDim);
+        }))
       tiledSliceDims.push_back(en.index());
   }
   return tiledSliceDims;
@@ -328,9 +296,10 @@
   if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
     return failure();
 
-  // Compute the slice dimensions tiled by `tileLoopNest`.
+  // Compute the tiled producer slice dimensions given the tiled root operation
+  // loop dimensions `loopDims`.
   SmallVector<int64_t> tiledSliceDims =
-      getTiledSliceDims(producerResult, rootOpOperand, loopDims);
+      getTiledSliceDims(rootOpOperand, loopDims);
   if (tiledSliceDims.empty())
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -230,3 +230,39 @@
   return %1 : tensor<24x25xi32>
 }
 
+// -----
+
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 18)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 18)>
+#map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+
+//      CHECK: fuse_non_rectangular
+// CHECK-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x18xf32>
+func @fuse_non_rectangular(%arg0: tensor<10x18xf32>,
+                           %arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x18xf32> -> tensor<10x18xf32>
+
+  //      CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      CHECK:   scf.for %[[IV1:[0-9a-zA-Z]*]] =
+
+  // Compute producer on a hyper rectangular bounding box. Along the second
+  // dimension, the offset is set to the sum of the induction variables and the
+  // size to the minimum of eight (sum of the tile sizes) and eighteen (sum of
+  // the domain sizes) minus the induction variables.
+  //      CHECK:     %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
+  //      CHECK:     %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
+  //      CHECK:     %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
+  //      CHECK:     %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                          %[[IV1]], %[[SUM]]
+  // CHECK-SAME:                                        , %[[UB1]]
+  //      CHECK:     %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x18xf32>) outs(%arg1 : tensor<10x8xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    %2 = addf %arg2, %arg3 : f32
+    linalg.yield %2 : f32
+  } -> tensor<10x8xf32>
+  return %1 : tensor<10x8xf32>
+}