diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -42,6 +42,7 @@ SmallVector fusedProducers; SmallVector fusedLoops; SmallVector, 1> unfusedLoops; + SmallVector promotedBuffers; }; /// Populates patterns for vectorization of all ConvN-D ops. @@ -72,83 +73,6 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); -/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This -/// proceeds as follows: -/// - Find outer parallel loops in these ops that can be fused. -/// - Tile fusable outer parallel loops of the last operation in the sequence. -/// - Fuse the remaining operations with the tiled operation -/// - Tile the unfused loops in each of the fused operations id needed. -/// -/// For example, consider the sequence of matmul below -/// -/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>) -/// outs(%arg2 : memref<256x32xf32>) -/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>) -/// outs(%arg4 : memref<256x32xf32>) -/// -/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the -/// matmuls row-wise. For example, the fused computation for the above is shown -/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling -/// along the rows of the matrix. The entire rows of the first matmul operation -/// need to be computed before they can be used for the second matmul. The -/// second matmul is further tiled (similar to normal tiling). 
-/// -/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)> -/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)> -/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) { -/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1] -/// : memref<256x32xf32> to memref<16x32xf32, #map0> -/// %3 = subview %arg1[0, 0] [32, 32] [1, 1] -/// : memref<32x32xf32> to memref<32x32xf32, #map1> -/// linalg.matmul -/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) -/// outs(%0 : memref<16x32xf32, #map0>) -/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) { -/// scf.for %arg7 = %c0 to %c32 step %c4 { -/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> -/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] -/// : memref<32x32xf32> to memref<4x8xf32, #map0> -/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> -/// linalg.matmul -/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) -/// outs(%6 : memref<16x8xf32, #map0>) -/// } -/// scf.yield -/// } -/// scf.yield -/// } -/// -/// `tilingOptions` are used to tile the corresponding operation in `ops` (the -/// size of the former should be same as size of the latter. Based on how -/// tile+fuse is implemented, the fused loops are generated based on the last -/// operation in the sequence. For example, the tile sizes for the fused loops -/// is obtained from `tilingOptions.back()`. The following tiling options are -/// handled differently in tile+fuse (compared to tile only) -/// - Interchange of the tiling loops is not supported right now. -/// - Distribution is only done for the fused loops. The tiled loops -/// generated by the second tiling is not distributed. 
-Optional -tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph, - ArrayRef tilingOptions); - -/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. -/// This is an in-place transformation controlled by `interchangeVector`. -/// An empty vector is interpreted as the identity permutation and the -/// transformation returns early. -/// -/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with -/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be -/// integers, in the range 0..`op.rank` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). -LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); - /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the /// smallest constant value for the size of the buffer needed for each @@ -242,6 +166,98 @@ } }; +/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This +/// proceeds as follows: +/// - Find outer parallel loops in these ops that can be fused. +/// - Tile fusable outer parallel loops of the last operation in the sequence. +/// - Fuse the remaining operations with the tiled operation +/// - Tile the unfused loops in each of the fused operations if needed. +/// +/// For example, consider the sequence of matmul below +/// +/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>) +/// outs(%arg2 : memref<256x32xf32>) +/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>) +/// outs(%arg4 : memref<256x32xf32>) +/// +/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the +/// matmuls row-wise. For example, the fused computation for the above is shown +/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling +/// along the rows of the matrix. 
The entire rows of the first matmul operation +/// need to be computed before they can be used for the second matmul. The +/// second matmul is further tiled (similar to normal tiling). +/// +/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)> +/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)> +/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) { +/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1] +/// : memref<256x32xf32> to memref<16x32xf32, #map0> +/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1] +/// : memref<256x32xf32> to memref<16x32xf32, #map0> +/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1] +/// : memref<256x32xf32> to memref<16x32xf32, #map0> +/// %3 = subview %arg1[0, 0] [32, 32] [1, 1] +/// : memref<32x32xf32> to memref<32x32xf32, #map1> +/// linalg.matmul +/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) +/// outs(%0 : memref<16x32xf32, #map0>) +/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) { +/// scf.for %arg7 = %c0 to %c32 step %c4 { +/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> +/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] +/// : memref<32x32xf32> to memref<4x8xf32, #map0> +/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> +/// linalg.matmul +/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) +/// outs(%6 : memref<16x8xf32, #map0>) +/// } +/// scf.yield +/// } +/// scf.yield +/// } +/// +/// `tilingOptions` are used to tile the corresponding operation in `ops` (the +/// size of the former should be same as size of the latter). Based on how +/// tile+fuse is implemented, the fused loops are generated based on the last +/// operation in the sequence. For example, the tile sizes for the fused loops +/// is obtained from `tilingOptions.back()`. 
The following tiling options are +/// handled differently in tile+fuse (compared to tile only) +/// - Interchange of the tiling loops is not supported right now. +/// - Distribution is only done for the fused loops. The tiled loops +/// generated by the second tiling is not distributed. +Optional +tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph, + ArrayRef tilingOptions, + AllocBufferCallbackFn allocationFn = nullptr); + +/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. +/// This is an in-place transformation controlled by `interchangeVector`. +/// An empty vector is interpreted as the identity permutation and the +/// transformation returns early. +/// +/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with +/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be +/// integers, in the range 0..`op.rank` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). +LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); + +/// Creates a new buffer using the `allocationFn` provided. The size of this +/// buffer is the smallest constant bounding size along each dimension that can +/// be computed for the size of the result of `subView`. Returns the allocated +/// buffer as `fullLocalView` and the view that matches the size of the result +/// of subview operation as `partialLocalView`. +struct PromotionInfo { + Value fullLocalView; + Value partialLocalView; +}; +Optional +promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, + AllocBufferCallbackFn allocationFn, + OperationFolder *folder = nullptr); + /// Promotes the `subViews` into a new buffer allocated at the insertion point /// `b`. Promotion occurs in 3 steps: /// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -25,6 +25,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -710,7 +711,7 @@ return fusableLoops; } -/// Find all dependences that are to be fusable. +/// Find all dependences that are fusable. static FusableOpDependencesTy findAllFusableDependences(ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { @@ -737,6 +738,7 @@ op.emitRemark("unhandled fusion of ops in different basic blocks"); return FusableOpDependencesTy{}; } + // Make sure that the indexing map of the view used for fusion in the // producer is a projected permutation. unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; @@ -804,27 +806,83 @@ /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected /// to be a tiled operation such that it is valid to fuse all operations in /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of -/// `tiledOp`. -static SmallVector -fuseOperations(OpBuilder &builder, LinalgOp tiledOp, +/// `tiledOp`. If `allocationFn` is not `nullptr`, the fused view will be +/// promoted. The function will be called to allocate memory for the promoted +/// view. 
+struct FuseAndPromoteInfo { + SmallVector fusedOps; + SmallVector promotedBuffers; +}; +static Optional +fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp, ArrayRef fusionCandidates, + const std::set &fusedLoops, const FusableOpDependencesTy &fusableDependences, - const std::set &fusedLoops) { + AllocBufferCallbackFn allocationFn = nullptr) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(tiledOp); DenseMap fusedLoopsAndRanges; - for (unsigned loop : fusedLoops) { - ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); - fusedLoopsAndRanges[loop] = getRangeFromOperandShape( + for (auto loop : llvm::enumerate(fusedLoops)) { + ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop.value()); + fusedLoopsAndRanges[loop.value()] = getRangeFromOperandShape( builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); } - SmallVector fusedOps(fusionCandidates.size()); - for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); - fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + + FuseAndPromoteInfo info; + DenseSet promotedViews; + DenseMap originalToFusedMap; + originalToFusedMap[rootOp.getOperation()] = tiledOp; + info.fusedOps.resize(fusionCandidates.size()); + for (auto candidateOps : enumerate(llvm::reverse(fusionCandidates))) { + unsigned candidateIndex = + fusionCandidates.size() - 1 - candidateOps.index(); + LinalgOp candidate = candidateOps.value(); + // There is at least one dependence from the candidate for fusion. 
+ auto dependences = fusableDependences.lookup(candidate); + if (dependences.empty()) { + candidate.emitRemark( + "missing dependence for operation that is to be fused"); + return llvm::None; + } + LinalgOp fusedOp = fuse(builder, candidate, fusedLoopsAndRanges); builder.setInsertionPoint(fusedOp); + + originalToFusedMap[candidate.getOperation()] = fusedOp; + unsigned producerIdx = dependences.front().dependentOpView.operandNum; + LinalgOp consumer = cast(dependences.front().indexingOpView.op); + unsigned consumerIdx = dependences.front().indexingOpView.operandNum; + if (allocationFn) { + // Promote the fused view. If the view in the consumer has already been + // promoted (since fusion+promotion happens in reverse), use that instead. + Value promotedView = nullptr; + LinalgOp fusedConsumer = originalToFusedMap[consumer.getOperation()]; + if (!fusedConsumer) { + consumer.emitRemark("unable to find the fused consumer"); + return llvm::None; + } + Value consumerView = fusedConsumer.getShapedOperand(consumerIdx); + if (promotedViews.count(consumerView)) { + promotedView = consumerView; + } else { + Optional promotionInfo = promoteSubviewAsNewBuffer( + builder, fusedOp.getLoc(), + fusedOp.getShapedOperand(producerIdx).getDefiningOp(), + allocationFn); + if (!promotionInfo) { + fusedOp.emitRemark("unable to promote operand ") << producerIdx; + return llvm::None; + } + info.promotedBuffers.push_back(promotionInfo->fullLocalView); + fusedConsumer.getOperation()->setOperand( + consumerIdx, promotionInfo->partialLocalView); + promotedView = promotionInfo->partialLocalView; + promotedViews.insert(promotedView); + } + fusedOp.getOperation()->setOperand(producerIdx, promotedView); + } + info.fusedOps[candidateIndex] = fusedOp; } - return fusedOps; + return info; } /// Post fusion, tile all the unfused loops of the fused operations. 
@@ -852,7 +910,8 @@ static Optional tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - ArrayRef tilingOptions) { + ArrayRef tilingOptions, + AllocBufferCallbackFn allocationFn) { if (ops.empty()) return llvm::None; LinalgOp rootOp = ops.back(); @@ -911,8 +970,16 @@ ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); // Fuse the other operations into the fused inter-tile loops produced above. - ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), - fusableDependences, tileFuseLoops); + { + Optional fuseAndPromoteInfo = + fuseOperations(builder, rootOp, ret.op, ops.drop_back(), tileFuseLoops, + fusableDependences, allocationFn); + if (!fuseAndPromoteInfo) { + return llvm::None; + } + ret.fusedProducers = std::move(fuseAndPromoteInfo->fusedOps); + ret.promotedBuffers = std::move(fuseAndPromoteInfo->promotedBuffers); + } // Tile the unfused loops for the fused ops. ret.unfusedLoops.resize(ops.size()); @@ -953,17 +1020,18 @@ return ret; } -Optional mlir::linalg::tileAndFuseLinalgOps( - OpBuilder &builder, ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph, - ArrayRef tilingOptions) { +Optional +mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph, + ArrayRef tilingOptions, + AllocBufferCallbackFn allocationFn) { switch (tilingOptions.back().loopType) { case LinalgTilingLoopType::Loops: return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, - tilingOptions); + tilingOptions, allocationFn); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - builder, ops, dependenceGraph, tilingOptions); + builder, ops, dependenceGraph, tilingOptions, allocationFn); default:; } return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -166,11 +166,6 @@ /// Alignment of promoted buffer. Optional alignment; }; - -struct PromotionInfo { - Value fullLocalView; - Value partialLocalView; -}; } // namespace LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( @@ -233,10 +228,9 @@ // To account for general boundary effects, padding must be performed on the // boundary tiles. For now this is done with an unconditional `fill` op followed // by a partial `copy` op. -static Optional -promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, - LinalgOpInstancePromotionOptions const &options, - OperationFolder *folder) { +Optional mlir::linalg::promoteSubviewAsNewBuffer( + OpBuilder &b, Location loc, SubViewOp subView, + AllocBufferCallbackFn allocationFn, OperationFolder *folder) { auto viewType = subView.getType(); auto rank = viewType.getRank(); SmallVector fullSizes, partialSizes; @@ -254,8 +248,7 @@ SmallVector dynSizes(fullSizes.size(), -1); // If a callback is not specified, then use the default implementation for // allocating the promoted buffer. 
- Optional fullLocalView = - options.allocationFn(b, subView, fullSizes, folder); + Optional fullLocalView = allocationFn(b, subView, fullSizes, folder); if (!fullLocalView) return {}; auto zero = folded_std_constant_index(folder, 0); @@ -279,8 +272,8 @@ for (auto v : options.subViews) { SubViewOp subView = cast(v.second.getDefiningOp()); - Optional promotionInfo = - promoteSubviewAsNewBuffer(b, loc, subView, options, folder); + Optional promotionInfo = promoteSubviewAsNewBuffer( + b, loc, subView, options.allocationFn, folder); if (!promotionInfo) return {}; promotionInfoMap[v.first] = *promotionInfo; diff --git a/mlir/test/Dialect/Linalg/fusion-sequence-promote.mlir b/mlir/test/Dialect/Linalg/fusion-sequence-promote.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-sequence-promote.mlir @@ -0,0 +1,150 @@ +// RUN: mlir-opt -pass-pipeline="test-linalg-tile-and-fuse{tile-sizes=16,32,64 promote-fused-view},canonicalize,cse" -split-input-file %s | FileCheck %s + +module { + func @three_op_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3 : memref) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = dim %arg0, %c0 : memref + %d1 = dim %arg1, %c1 : memref + %0 = alloc(%d0, %d1) : memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : memref, memref) + outs(%arg3 : memref) { + ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : + %5 = addf %arg4, %arg5 : f32 + linalg.yield %5 : f32 + } + return + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = 
affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK: func @three_op_fusion +// CHECK-DAG: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[N:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: scf.parallel (%[[IV0:[a-zA-Z0-9_]+]], %[[IV1:[a-zA-Z0-9_]+]]) +// CHECK-SAME: { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]] +// CHECK: %[[K:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[ALLOC:.+]] = alloc(%[[TILE_M]], %[[TILE_N]]) +// CHECK: %[[SV_ALLOC:.+]] = subview %[[ALLOC]][0, 0] [%[[TILE_M]], %[[TILE_N]]] +// CHECK: linalg.fill(%[[SV_ALLOC]], %{{.+}}) +// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C64]] +// CHECK-SAME: { +// CHECK: linalg.matmul +// CHECK-SAME: outs(%[[SV_ALLOC]] : memref) +// CHECK: } +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[SV_ALLOC]], %{{[a-zA-Z0-9_]+}} +// CHECK-SAME: : memref, memref) +// CHECK: } + +// ----- + +module { + func @sequence_of_matmul(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref, + %arg4: memref) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %m = dim %arg0, %c0 : memref + %n1 = dim %arg1, %c1 : memref + %n2 = dim %arg2, %c1 : memref + %n3 = dim %arg3, %c1 : memref + %0 = alloc(%m, %n1) : memref + %1 = alloc(%m, %n2) : memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.fill(%1, 
%cst) : memref, f32 + linalg.matmul ins(%0, %arg2 : memref, memref) + outs(%1 : memref) + linalg.fill(%arg4, %cst) : memref, f32 + linalg.matmul ins(%1, %arg3 : memref, memref) + outs(%arg4 : memref) + return + } +} + +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK: func @sequence_of_matmul +// CHECK: scf.parallel (%[[PIV:.+]]) +// CHECK-SAME: { +// CHECK: %[[ALLOC1:.+]] = alloc +// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][0, 0] +// CHECK: %[[ALLOC2:.+]] = alloc +// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][0, 0] +// CHECK: %[[ALLOC3:.+]] = alloc +// CHECK: %[[SV_ALLOC3:.+]] = subview %[[ALLOC3]][0, 0] +// CHECK: scf.parallel (%[[IV0:.+]]) = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC3_2:.+]] = subview %[[SV_ALLOC3]][0, %[[IV0]]] +// CHECK: linalg.fill(%[[SV_ALLOC3_2]], %{{[a-zA-Z0-9_]+}}) +// CHECK: } +// CHECK: scf.parallel (%[[IV1:.+]]) = +// CHECK-SAME: { +// CHECK: scf.for %[[IV2:.+]] = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC3_3:.+]] = subview %[[SV_ALLOC3]][0, %[[IV1]]] +// CHECK: linalg.matmul +// CHECK-SAME: outs(%[[SV_ALLOC3_3]] : memref) +// CHECK: } +// CHECK: } +// CHECK: scf.parallel (%[[IV3:.+]]) = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC2_2:.+]] = subview %[[SV_ALLOC2]][0, %[[IV3]]] +// CHECK: linalg.fill(%[[SV_ALLOC2_2]], %{{[a-zA-Z0-9_]+}}) +// CHECK: } +// CHECK: scf.parallel (%[[IV4:.+]]) = +// CHECK-SAME: { +// CHECK: scf.for %[[IV5:.+]] = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC3_4:.+]] = subview %[[SV_ALLOC3]][0, %[[IV5]]] +// CHECK: %[[SV_ALLOC2_3:.+]] = subview %[[SV_ALLOC2]][0, %[[IV4]]] +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[SV_ALLOC3_4]], %{{[a-zA-Z0-9_]+}} : +// CHECK-SAME: memref, memref) +// CHECK-SAME: outs(%[[SV_ALLOC2_3]] : memref) +// CHECK: } +// CHECK: } +// CHECK: scf.parallel (%[[IV6:.+]]) = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC1_2:.+]] = subview %[[SV_ALLOC1]][0, %[[IV6]]] +// CHECK: linalg.fill(%[[SV_ALLOC1_2]], %{{[a-zA-Z0-9_]+}}) +// CHECK: } +// CHECK: 
scf.parallel (%[[IV7:.+]]) = +// CHECK-SAME: { +// CHECK: scf.for %[[IV8:.+]] = +// CHECK-SAME: { +// CHECK: %[[SV_ALLOC2_4:.+]] = subview %[[SV_ALLOC2]][0, %[[IV8]]] +// CHECK: %[[SV_ALLOC1_3:.+]] = subview %[[SV_ALLOC1]][0, %[[IV7]]] +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[SV_ALLOC2_4]], %{{[a-zA-Z0-9_]+}} : +// CHECK-SAME: memref, memref) +// CHECK-SAME: outs(%[[SV_ALLOC1_3]] : memref) +// CHECK: } +// CHECK: } diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -198,6 +198,20 @@ } }; +/// Callback function for allocating memory during promotion in +/// `TestLinalgTileAndFuseSequencePass`. +static Optional +allocCallBackFunction(OpBuilder &b, SubViewOp subView, + ArrayRef boundingSubViewSize, + OperationFolder *folder) { + MemRefType type = + MemRefType::get(SmallVector(boundingSubViewSize.size(), + ShapedType::kDynamicSize), + subView.getType().getElementType()); + return b.create(subView.getLoc(), type, boundingSubViewSize) + .getResult(); +} + /// Pass to test tile and fuse of sequence of operations. Intended only for /// testing. struct TestLinalgTileAndFuseSequencePass @@ -206,6 +220,9 @@ TestLinalgTileAndFuseSequencePass( const TestLinalgTileAndFuseSequencePass &pass){}; + Option promoteFusedView{*this, "promote-fused-view", + llvm::cl::desc("Promote fused view"), + llvm::cl::init(false)}; ListOption tileSizes{ *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; @@ -230,7 +247,8 @@ LinalgTilingLoopType::ParallelLoops)); OpBuilder builder(funcOp.getContext()); Optional tileAndFuseOps = tileAndFuseLinalgOps( - builder, linalgOps, dependenceGraph, tilingOptions); + builder, linalgOps, dependenceGraph, tilingOptions, + (promoteFusedView ? 
allocCallBackFunction : nullptr)); if (!tileAndFuseOps) return signalPassFailure(); for (auto op : linalgOps)