diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -233,8 +233,8 @@ bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp); LinalgOp rootOp; - SmallVector loopOps; - SmallVector loopDims; + SmallVector tileLoopOps; + DenseMap> tiledRootAndFusedOpsLoops; }; /// Tiles `consumerOp` and fuses its dependencies if possible. Uses the diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -52,6 +52,31 @@ return {tiledSliceDims.begin(), tiledSliceDims.end()}; } +/// Returns the tiled producer loop dimensions mapped to the tiled result slice +/// dimensions `tiledSliceDims`. +SmallVector getTiledProducerLoops(OpResult producerResult, + ArrayRef tiledSliceDims) { + // Get the producer. + LinalgOp producerOp = producerResult.getOwner(); + + // Get the producer result indexing map. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( + producerOp.getOutputOperand(producerResult.getResultNumber())); + + // Compute the producer loops assuming a projected permutation indexing map. + SmallVector tiledProducerLoops; + transform(tiledSliceDims, std::back_inserter(tiledProducerLoops), + [&](int64_t tiledSliceDim) { + AffineExpr result = + producerIndexingMap.getResults()[tiledSliceDim]; + assert(result.isa() && + "expect producer indexing map is a projected permutation"); + return result.cast().getPosition(); + }); + + return tiledProducerLoops; +} + /// Returns the producer fused in place of `sliceOp`. Tile the producer operands /// along the `tiledSliceDims` and clone the producer.
Consider the case of /// fusion of an output tensor: @@ -85,6 +110,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, tensor::ExtractSliceOp sliceOp, ArrayRef tiledSliceDims, + ArrayRef tiledProducerLoops, OpOperand *iterArg) { // Clone the producer after `sliceOp` since the slice may be reused to pass in // the producer result. @@ -102,23 +128,16 @@ [](Range range) { return range.size; }); SmallVector sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); - // Get the producer result indexing map. - AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( - producerOp.getOutputOperand(producerResult.getResultNumber())); - // Tile the producer operands given the `sliceOp` ranges. Iterate the // `tiledSliceDims` and store the tile offset and size for the tiled slice - // dimension. Assumes the mapping from slice dimensions to producer loops is a - // permutation. + // dimension. auto zero = b.create(loc, 0); SmallVector tileIvs(producerOp.getNumLoops(), nullptr); SmallVector tileSizes(producerOp.getNumLoops(), zero); SmallVector allIvs(producerOp.getNumLoops(), nullptr); - for (int64_t tiledSliceDim : tiledSliceDims) { - AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim]; - assert(result.isa() && - "expect producer indexing map is a projected permutation"); - int64_t tiledProducerLoop = result.cast().getPosition(); + for (auto it : zip(tiledSliceDims, tiledProducerLoops)) { + int64_t tiledSliceDim, tiledProducerLoop; + std::tie(tiledSliceDim, tiledProducerLoop) = it; tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; @@ -156,22 +175,26 @@ // TileLoopNest specific helpers. 
//===----------------------------------------------------------------------===// -bool TileLoopNest::isEmpty() { return loopOps.empty(); } +bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } bool TileLoopNest::isValid() { - // Check if the number of `tileLoopOps` and `tileLoopDims` match. - if (loopOps.size() != loopDims.size()) + // Check if `rootOp` has been tiled at least once. + if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) + return false; + + // Check if the number of loop operations and dimensions match. + if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) return false; // Check if the innermost tile loop is the parent of `tiledOp`. - if (rootOp->getParentOp() != loopOps.back()) + if (rootOp->getParentOp() != tileLoopOps.back()) return false; // Check if the tile loops are directly nested. - return std::adjacent_find(loopOps.begin(), loopOps.end(), + return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), [](Operation *op1, Operation *op2) { return op1 != op2->getParentOp(); - }) == loopOps.end(); + }) == tileLoopOps.end(); } SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { @@ -179,7 +202,7 @@ SmallVector bbArgs; // Search all tile loop block arguments from inner to outer. - for (auto tileLoop : reverse(loopOps)) { + for (auto tileLoop : reverse(tileLoopOps)) { if (bbArg.getOwner()->getParentOp() != tileLoop) return {}; bbArgs.push_back(bbArg); @@ -194,9 +217,9 @@ OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { // Search all block arguments and return the matching iteration argument. 
SmallVector bbArgs = getTiedBBArgs(bbArg); - if (bbArgs.size() != loopOps.size()) + if (bbArgs.size() != tileLoopOps.size()) return nullptr; - return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); + return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); } bool TileLoopNest::hasOtherUses(BlockArgument bbArg, @@ -255,24 +278,29 @@ if (!isEmpty()) rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); + // Transfer the stored `rootOp` loop dimensions if it has been tiled before. + if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { + tiledRootAndFusedOpsLoops[tiledRootOp->op] = + tiledRootAndFusedOpsLoops[rootOp]; + } + // Update the root operation and append the loops and tile loop dimensions. rootOp = tiledRootOp->op; - loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); for (auto en : enumerate(tileSizes)) { // Copy only the tiled loop dimensions with non-zero tile size. if (en.value() == 0) continue; - loopDims.push_back(tileInterchange[en.index()]); + tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); } assert(isValid() && "expect tile loop nest to be valid after tiling"); - return success(); } FailureOr TileLoopNest::fuseProducer(OpBuilder &b, - OpOperand *rootOpOperand) { - assert(rootOpOperand->getOwner() == rootOp && - "expect the root op to be the owner of the operand to fuse"); + OpOperand *consumerOpOperand) { + assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 && + "expect the operand owner is the root operation or a fused producer"); assert(this->isValid() && "expect the tile loop nest to satisfy all invariants"); @@ -280,13 +308,16 @@ if (isEmpty()) return failure(); - // Check `rootOpOperand` is defined by an ExtractSliceOp. - auto sliceOp = rootOpOperand->get().getDefiningOp(); + // Check `consumerOpOperand` is defined by an ExtractSliceOp. 
+ auto sliceOp = + consumerOpOperand->get().getDefiningOp(); if (!sliceOp) return failure(); - // Check `sliceOp` is tiled by the tile loop nest. - if (sliceOp->getParentOp() != rootOp->getParentOp()) + // Check `sliceOp` and `consumerOp` are tiled by the tile loop nest. + LinalgOp consumerOp = consumerOpOperand->getOwner(); + if (sliceOp->getParentOp() != rootOp->getParentOp() || + consumerOp->getParentOp() != rootOp->getParentOp()) return failure(); // Check if the producer is a LinalgOp possibly passed by iteration argument. @@ -302,19 +333,23 @@ if (!producerResult || !isa(producerResult.getOwner())) return failure(); - // Compute the tiled producer slice dimensions given the tiled root operation - // loop dimensions `loopDims`. - SmallVector tiledSliceDims = - getTiledSliceDims(rootOpOperand, loopDims); + // Compute the tiled producer slice dimensions given the tiled consumer loops. + SmallVector tiledSliceDims = getTiledSliceDims( + consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); if (tiledSliceDims.empty()) return failure(); + // Compute the tiled producer loops. + SmallVector tiledProducerLoops = + getTiledProducerLoops(producerResult, tiledSliceDims); + // Tile the producer operands and clone the producer in place of `sliceOp`. - LinalgOp clonedOp = - getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg); + LinalgOp clonedOp = getTiledProducer( + b, producerResult, sliceOp, tiledSliceDims, tiledProducerLoops, iterArg); + tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoops; // Cast the `clonedOp` result to gap type mismatches before canonicalization. 
- Type consumerOperandType = rootOpOperand->get().getType(); + Type consumerOperandType = consumerOpOperand->get().getType(); Value newResult = clonedOp->getResult(producerResult.getResultNumber()); if (newResult.getType() != consumerOperandType) { OpBuilder::InsertionGuard guard(b); @@ -330,7 +365,7 @@ ValueRange TileLoopNest::getRootOpReplacementResults() { assert(!isEmpty() && "expect tile loop nest to be non-empty"); - return loopOps.front()->getOpResults(); + return tileLoopOps.front()->getOpResults(); } //===----------------------------------------------------------------------===// @@ -359,14 +394,26 @@ }); int64_t split = std::distance(iterTypes.begin(), it); + // Helper to fuse the producers greedily using a stack of fusion candidates. + auto fuseProducersGreedily = [&](ArrayRef operands) { + SmallVector candidates(operands.begin(), operands.end()); + while (!candidates.empty()) { + FailureOr fusedProducer = + tileLoopNest.fuseProducer(b, candidates.back()); + candidates.pop_back(); + if (failed(fusedProducer)) + continue; + candidates.append(fusedProducer->getInputAndOutputOperands()); + } + }; + // Tile the outer parallel loops and fuse the output operands. SmallVector outerTileSizes; outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); outerTileSizes.append(tileSizes.size() - split, 0); if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange))) return failure(); - for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands()) - (void)tileLoopNest.fuseProducer(b, opOperand); + fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); // Tile the remaining loops and fuse the input operands.
SmallVector innerTileSizes; @@ -374,10 +421,7 @@ innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange))) return failure(); - SmallVector inputOperands = - tileLoopNest.getRootOp().getInputOperands(); - for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands()) - (void)tileLoopNest.fuseProducer(b, opOperand); + fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); return tileLoopNest; } diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s + +// CHECK: fuse_conv_chain +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32> +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32> +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32> +// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32> +builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>, + %arg1: tensor<11x11xf32>, + %arg2: tensor<10x10xf32>, + %arg3: tensor<9x9xf32>, + %arg4: tensor<8x8xf32>) -> tensor<8x8xf32> { + %cst = arith.constant 1.0 : f32 + + // Do not tile the filter fill since the filter dimensions are not tiled. + // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32> + + // Fuse all other operations. 
+ // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]] + + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG2]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CHECK: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]] + %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32> + %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32> + + // CHECK: %[[T5:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]]) + // CHECK: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]] + %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32> + %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32> + + // Use the argument passed in by iteration argument. 
+ // CHECK: %[[T8:.*]] = tensor.extract_slice %[[ARG6]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]]) + // CHECK: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]] + %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32> + return %6 : tensor<8x8xf32> +} + +// ----- + +// CHECK: fuse_matmul_chain +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32> +builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c25 = arith.constant 25 : index + %c24 = arith.constant 24 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.000000e+00 : f32 + + // Do not fuse rhs fill of the producer matmul since only its outermost loop is tiled. + // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + + // Only the outermost loop of the producer matmul is tiled. + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV0]], 0 + // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}} + %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + + // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]] + %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + return %2 : tensor<8x8xf32> +}