diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -59,104 +59,6 @@
 /// More advanced use cases, analyses as well as profitability heuristics are
 /// left for future work.
 
-// Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
-// by `permutationMap`.
-static void inferShapeComponents(AffineMap permutationMap,
-                                 ArrayRef<Range> loopRanges,
-                                 SmallVectorImpl<Value> &offsets,
-                                 SmallVectorImpl<Value> &sizes,
-                                 SmallVectorImpl<Value> &strides) {
-  assert(permutationMap.isProjectedPermutation() &&
-         "expected some subset of a permutation map");
-  SmallVector<Range> shapeRanges(permutationMap.getNumResults());
-  unsigned idx = 0;
-  for (AffineExpr e : permutationMap.getResults()) {
-    // loopToOperandRangesMaps are permutations-only, just swap indices.
-    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
-    shapeRanges[idx++] = loopRanges[loopPos];
-  }
-  // Construct a new subshape for the tile.
-  unsigned rank = shapeRanges.size();
-  offsets.reserve(rank);
-  sizes.reserve(rank);
-  strides.reserve(rank);
-  for (auto r : shapeRanges) {
-    offsets.push_back(r.offset);
-    sizes.push_back(r.size);
-    strides.push_back(r.stride);
-  }
-}
-
-// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
-// a subset of the original loop ranges of `op`.
-// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
-// to the `loopRanges` in order to obtain view ranges.
-static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
-                                    ArrayRef<Range> loopRanges) {
-  SmallVector<Value> clonedShapes;
-  clonedShapes.reserve(op.getNumShapedOperands());
-
-  // Iterate over the shape operands in order.
-  // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
-    unsigned shapedOperandIdx = en.index();
-    AffineMap map = op.getIndexingMap(shapedOperandIdx);
-    LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
-                            << " with indexingMap: " << map << "\n");
-    SmallVector<Value> offsets, sizes, strides;
-    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
-    Value shape = en.value();
-    Value sub =
-        shape.getType().isa<MemRefType>()
-            ? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
-                  .getResult()
-            : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
-                  .getResult();
-    clonedShapes.push_back(sub);
-  }
-  // Append the other operands.
-  auto operands = op.getAssumedNonShapedOperands();
-  clonedShapes.append(operands.begin(), operands.end());
-
-  // Iterate over the results in order.
-  // Extract the subtensor type from the linearized range.
-  // Since we do not enforce any canonicalizations on the fly, this is always
-  // fully dynamic at construction time.
-  SmallVector<Type> resultTypes;
-  resultTypes.reserve(op->getNumResults());
-  for (RankedTensorType t : op.getOutputTensorTypes()) {
-    unsigned rank = t.getRank();
-    SmallVector<int64_t> staticOffsetsVector(
-        rank, ShapedType::kDynamicStrideOrOffset);
-    SmallVector<int64_t> staticSizesVector(rank, ShapedType::kDynamicSize);
-    SmallVector<int64_t> staticStridesVector(
-        rank, ShapedType::kDynamicStrideOrOffset);
-    resultTypes.push_back(SubTensorOp::inferResultType(
-        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
-        staticStridesVector));
-  }
-
-  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
-  // When the producer is an IndexedGenericOp, we have to transform its block
-  // IV arguments according to the tiling of the consumer, i.e. offset them by
-  // the values computed in `loopRanges`.
-  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
-    auto &block = indexedGenericOp.region().front();
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPointToStart(&block);
-    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
-      Value oldIndex = block.getArgument(i);
-      // TODO: replace by an affine_apply.
-      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
-                                         loopRanges[i].offset);
-      oldIndex.replaceAllUsesExcept(newIndex,
-                                    SmallPtrSet<Operation *, 1>{newIndex});
-    }
-  }
-
-  return clonedOp;
-}
-
 struct ShapeDimension {
   Value shape;
   unsigned dimension;
@@ -208,35 +110,86 @@
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
+/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
 /// provides the loop range information for the fused loops. The rest are
 /// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
+static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
-
-  unsigned nPar = producer.getNumParallelLoops();
-  unsigned nRed = producer.getNumReductionLoops();
-  unsigned nWin = producer.getNumWindowLoops();
-  SmallVector<Range> loopRanges(nPar + nRed + nWin);
-  for (auto fusedLoops : fusedLoopsAndRanges)
-    loopRanges[fusedLoops.first] = fusedLoops.second;
-
-  // Iterate over all dimensions. For the dimensions not identified by the
-  // producer map for `producerIdx`, we need to explicitly compute the shape
-  // that defines the loop ranges using the `producer`.
-  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
-    if (loopRanges[i].offset)
-      LLVM_DEBUG(llvm::dbgs()
-                 << "existing LoopRange: " << loopRanges[i] << "\n");
-    else {
+  SmallVector<Value> ivs, tileSizes, sizeBounds;
+  SmallVector<Range> loopRanges;
+  auto zero = std_constant_index(0);
+  auto one = std_constant_index(1);
+  Location loc = producer.getLoc();
+
+  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
+    auto it = fusedLoopsAndRanges.find(i);
+    if (it != fusedLoopsAndRanges.end()) {
+      ivs.push_back(it->second.offset);
+      tileSizes.push_back(it->second.size);
+      sizeBounds.push_back(nullptr);
+      loopRanges.push_back(it->second);
+      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
+                              << loopRanges.back() << "\n");
+    } else {
       auto shapeDim = getShapeDefiningLoopRange(producer, i);
       Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
-      loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
-      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+      tileSizes.push_back(zero);
+      sizeBounds.push_back(dim);
+      loopRanges.push_back(Range{zero, dim, one});
+      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
+                              << loopRanges.back() << "\n");
+    }
+  }
+
+  SmallVector<Value> clonedShapes;
+  clonedShapes.reserve(producer.getNumShapedOperands());
+
+  // Compute subranges for all tensor input/output operands.
+  auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
+  clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands,
+                                      ivs, tileSizes, sizeBounds));
+
+  // Append the other operands.
+  auto operands = producer.getAssumedNonShapedOperands();
+  clonedShapes.append(operands.begin(), operands.end());
+
+  // Iterate over the results in order.
+  // Extract the subtensor type from the linearized range.
+  // Since we do not enforce any canonicalizations on the fly, this is always
+  // fully dynamic at construction time.
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(producer->getNumResults());
+  for (RankedTensorType t : producer.getOutputTensorTypes()) {
+    unsigned rank = t.getRank();
+    SmallVector<int64_t> staticOffsetsVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    SmallVector<int64_t> staticSizesVector(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t> staticStridesVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    resultTypes.push_back(SubTensorOp::inferResultType(
+        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+        staticStridesVector));
+  }
+
+  Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes);
+  // When the producer is an IndexedGenericOp, we have to transform its block
+  // IV arguments according to the tiling of the consumer, i.e. offset them by
+  // the values computed in `loopRanges`.
+  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
+    auto &block = indexedGenericOp.region().front();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToStart(&block);
+    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
+      Value oldIndex = block.getArgument(i);
+      // TODO: replace by an affine_apply.
+      AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(),
+                                               oldIndex, loopRanges[i].offset);
+      oldIndex.replaceAllUsesExcept(newIndex,
+                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
   }
-  return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
+  return clonedOp;
 }
 
 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -27,6 +27,9 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-utils"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -447,11 +450,14 @@
   // that define tile subshapes.
   SmallVector<Value> lbs, subShapeSizes;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
     bool isTiled = !isZero(tileSizes[idx]);
     lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
     // Before composing, we need to make range a closed interval.
     Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
     subShapeSizes.push_back(size - std_constant_index(1));
+    LLVM_DEBUG(llvm::dbgs() << "lb: " << lbs.back() << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
   }
 
   MLIRContext *context = builder.getContext();
@@ -459,14 +465,18 @@
   tiledShapes.reserve(tiledOperands.size());
   for (auto en : llvm::enumerate(tiledOperands)) {
     Value shapedOp = en.value();
+    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
     unsigned rank = shapedType.getRank();
     AffineMap map = linalgOp.getIndexingMap(en.index());
     // If the shape is not tiled, we can use it as is.
     if (!isTiled(map, tileSizes)) {
       tiledShapes.push_back(shapedOp);
+      LLVM_DEBUG(llvm::dbgs()
+                 << ": not tiled: use shape: " << shapedType << "\n");
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
 
     // Construct a new subview / subtensor for the tile.
SmallVector offsets, sizes, strides; @@ -474,22 +484,28 @@ sizes.reserve(rank); strides.reserve(rank); for (unsigned r = 0; r < rank; ++r) { + LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for dim#" << r); if (!isTiled(map.getSubMap({r}), tileSizes)) { offsets.push_back(builder.getIndexAttr(0)); - sizes.push_back(memref_dim(shapedOp, r).value); + Value dim = memref_dim(shapedOp, r).value; + sizes.push_back(dim); strides.push_back(builder.getIndexAttr(1)); + LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n"); continue; } + LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n"); // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the op does not subsample, stepping occurs in the loop). auto m = map.getSubMap({r}); + LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: submap: " << map << "\n"); auto offset = applyMapToValues(builder, loc, m, lbs).front(); offsets.push_back(offset); auto closedIntSize = applyMapToValues(builder, loc, m, subShapeSizes).front(); // Resulting size needs to be made half open interval again. auto size = closedIntSize + std_constant_index(1); + LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: raw size: " << size << "\n"); // The size of the subview / subtensor should be trimmed to avoid // out-of-bounds accesses, unless we statically know the subshape size @@ -498,6 +514,9 @@ auto sizeCst = size.getDefiningOp(); if (ShapedType::isDynamic(shapeSize) || !sizeCst || (shapeSize % sizeCst.getValue()) != 0) { + LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: shapeSize=" << shapeSize + << ", size: " << size + << ": make sure in bound with affine.min\n"); AffineExpr dim0, dim1, dim2; bindDims(context, dim0, dim1, dim2); // Compute min(size, dim - offset) to avoid out-of-bounds accesses. @@ -510,6 +529,9 @@ } sizes.push_back(size); + LLVM_DEBUG(llvm::dbgs() + << "makeTiledShapes: new offset: " << offset << "\n"); + LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: new size: " << size << "\n"); strides.push_back(builder.getIndexAttr(1)); } diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -16,6 +16,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> // CHECK: func @basic_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -47,8 +48,10 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] +// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV1]], %[[TILE_N]])[%[[N_2]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]] // CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { @@ -86,6 +89,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> // 
CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> // CHECK: func @rhs_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -112,10 +116,13 @@ // CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][0, %[[IV0]]] // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] // CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]] +// CHECK: %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N_3]]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK-SAME: [%[[K_2]], %[[TILE_N_3]]] +// CHECK: %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK-SAME: [%[[K]], %[[TILE_N_4]]] // CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer" // CHECK-NOT: linalg.fill @@ -164,6 +171,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> // CHECK: func @two_operand_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -191,13 +199,17 @@ // CHECK: %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] // CHECK: %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[N]]] +// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] // CHECK: %[[K_2:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[K_2]]] +// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK-SAME: [%[[TILE_M_5]], %[[K]]] // CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" // CHECK: linalg.fill(%[[SV2_2]], %[[CST]]) @@ -271,23 +283,24 @@ // CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]] // CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[K2_2:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] // CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K1]]] -// CHECK: %[[SV4:.+]] = memref.subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]] // CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[K2]]] // CHECK: 
linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[SV3]], %[[SV4]] -// CHECK-SAME: : memref, memref) +// CHECK-SAME: ins(%[[SV3]], %[[ARG1]] +// CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV1_2]] : memref) -// CHECK-DAG: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]] +// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = // CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { -// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { -// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]] +// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] { +// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]] // CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] // CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] // CHECK: %[[K_2:.+]] = memref.dim %[[ARG3]], %[[C0]] @@ -348,10 +361,11 @@ // CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] // CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0] // CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]] +// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]] // CHECK: linalg.matmul // CHECK-SAME: after_transpose_fusion_producer // CHECK-SAME: ins(%[[T8]], %[[T9]] -// CHECK-SAME: outs(%[[T5]] +// CHECK-SAME: outs(%[[T10]] // CHECK-NOT: linalg.matmul // CHECK: linalg.generic // CHECK-SAME: ins(%[[T5]], %[[T5]] diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -36,18 +36,19 @@ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref // CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref // CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { -// CHECK-DAG: %[[SV_TEMP:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] +// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] // CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]] // CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]] // CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] -// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}}) +// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] +// CHECK: linalg.fill(%[[SV_TEMP_2]], %{{.+}}) // CHECK: linalg.matmul // CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] // CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_TEMP]] : memref) +// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref) // CHECK: linalg.generic -// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]] +// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]] // CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ARG3]] : memref) // CHECK: scf.yield @@ -83,6 +84,8 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> + // CHECK: func @sequence_of_matmul // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -100,37 +103,40 @@ // CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) // CHECK-SAME: step (%[[C16]]) { // CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0] +// CHECK: %[[SV_ALLOC3:.+]] = 
memref.subview %[[ALLOC2]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M]], %[[N2]]] // CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]] // CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] // CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]] // CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] // CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N3]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]] // CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N1]]] -// CHECK: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]] +// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_4]], %[[N2]]] // CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]] -// CHECK: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[N0]]] // CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]] -// CHECK-SAME: : memref, memref) +// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]] +// CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ALLOC1]] : memref) // CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]] -// CHECK-SAME: : memref, memref) +// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]] +// CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ALLOC2]] : memref) // CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]] +// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]] // CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ARG4]] : memref) // CHECK: scf.yield // CHECK: } + // ----- module { @@ -189,8 +195,8 @@ module { func @tensor_matmul_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, - %arg4: tensor, %arg5: tensor, - %arg6: tensor) -> tensor { + %arg4: tensor, %arg5: tensor, + %arg6: tensor) -> tensor { %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) @@ -200,7 +206,12 @@ return %2 : tensor } } -// CHECK-LABEL: func @tensor_matmul_fusion( + +// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> + +// CHECK: func @tensor_matmul_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor @@ -210,36 +221,39 @@ // CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) -> tensor { // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK: %[[M:.+]] = memref.dim %[[ARG0]], %c0 : tensor // CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = // CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { -// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]] -// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0] -// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N3]]] -// CHECK: %[[N2:.+]] = memref.dim %[[ARG3]], 
%[[C1]] -// CHECK: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0] -// CHECK-SAME: [%[[N1]], %[[N2]]] -// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N2]]] -// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N0]]] -// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0] -// CHECK-SAME: [%[[N0]], %[[N1]]] -// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N1]]] -// CHECK: %[[T0:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] -// CHECK-SAME: ) outs(%[[STARG2]] : tensor) -// CHECK: %[[T1:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T0]], %[[STARG3]] -// CHECK-SAME: ) outs(%[[STARG4]] : tensor) -// CHECK: %[[T2:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T1]], %[[ARG5]] -// CHECK-SAME: ) outs(%[[STARG6]] : tensor) -// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]] -// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] -// CHECK: scf.yield %[[R1]] -// CHECK: } -// CHECK: return %[[R0]] +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[M_1:.+]] = memref.dim %[[ARG8]], %[[C0]] +// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP1]](%[[M_1]], %[[IV0]]) +// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]] +// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]] +// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] +// CHECK: %[[N2:.+]] = memref.dim %[[ARG4]], %[[C1]] +// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_3]], %[[N0]]] +// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] +// CHECK: %[[N1:.+]] = memref.dim %[[ARG2]], %[[C1]] +// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]] +// CHECK: %[[T0:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor, tensor +// CHECK-SAME: ) outs(%[[STARG2]] : tensor) +// CHECK: %[[T1:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[T0]], %arg3 : tensor, tensor +// CHECK-SAME: ) outs(%[[STARG4]] : tensor) +// CHECK: %[[T2:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[T1]], %arg5 : tensor, tensor +// CHECK-SAME: ) outs(%[[STARG6]] : tensor) +// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]] +// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]] +// CHECK: scf.yield %[[R1]] : tensor // CHECK: } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -17,12 +17,15 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> + // CHECK: func @matmul_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor + // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[C32:.+]] = constant 32 : index @@ -38,18 +41,20 @@ // CHECK: %[[N3:.+]] = memref.dim %[[ARG6]], %[[C1]] // CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M]]] // CHECK: %[[N1:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N1]]] -// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0] -// CHECK-SAME: [%[[N1]], %[[N2]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]] +// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] +// CHECK: %[[N2_2:.+]] = memref.dim %[[ARG2]], %[[C1]] // CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N2]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[N2_2]]] // CHECK: %[[LHS:.+]] = linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor, tensor) +// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor, tensor) // CHECK-SAME: outs(%[[ST_ARG2]] : tensor) +// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]] // CHECK: %[[N3_2:.+]] = memref.dim %[[ARG3]], %[[C1]] // CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = // CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]] @@ -59,7 +64,7 @@ // CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor) { // CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]] // CHECK: %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]] +// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]] // CHECK: %[[N2_3:.+]] = memref.dim %[[ARG3]], %[[C0]] // CHECK: %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]] // CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]] diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -252,25 +252,36 @@ } return %E : memref } -// CHECK-LABEL: func @f5 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) +// CHECK: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// CHECK: #[[BOUND_ID_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> +// CHECK: func @f5 +// HECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1:.*]] : memref -// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0:.*]] : memref -// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref +// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref +// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref +// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref +// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref // CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 
0]{{.*}} // CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} { -// CHECK-DAG: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0] -// CHECK-DAG: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] +// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]] +// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]] +// CHECK: %[[BOUND_2_D0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[D_0]]] +// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0] +// Note that %[[BOUND_ID_C0]] is essentially %[[BOUND_2_C0]]. +// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_ID_MAP]](%[[I]], %[[BOUND_2_C0]])[%[[C_0]]] +// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]] // CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} { // CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]] // CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK-DAG: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] -// CHECK-DAG: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]] -// CHECK-DAG: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]] -// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]] -// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]] +// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4] +// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]] +// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]] +// CHECK: %[[BOUND_4_D1:.+]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[D_1]]] +// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_D0]], %[[BOUND_4_D1]]] +// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]] +// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]] // CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]] // ----- diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -1,11 +1,5 @@ // RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s -#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)> -#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)> -#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)> -#map3 = affine_map<(d0, d1) -> (2, d0 - d1)> -#map4 = affine_map<(d0, d1) -> (3, d0 - d1)> - func @matmul_tensors(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %t0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) outs(%arg2: tensor) @@ -36,23 +30,250 @@ return %3 : tensor } -// CHECK-LABEL: func @matmul_tensors( +// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> + +// CHECK: func @matmul_tensors( // CHECK-SAME: %[[A:[0-9a-z]*]]: tensor // CHECK-SAME: %[[B:[0-9a-z]*]]: tensor // CHECK-SAME: %[[C:[0-9a-z]*]]: tensor + // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index +// CHECK-DAG: %[[dA0:.*]] = memref.dim %[[A]], %[[C0]] : tensor // CHECK-DAG: %[[dA1:.*]] = memref.dim %[[A]], %[[C1]] : tensor +// CHECK-DAG: %[[dB0:.*]] = memref.dim %[[B]], %[[C0]] : tensor +// CHECK-DAG: %[[dB1:.*]] = memref.dim %[[B]], %[[C1]] : tensor +// CHECK-DAG: %[[dC0:.*]] = memref.dim %[[C]], %[[C0]] : tensor +// CHECK-DAG: %[[dC1:.*]] = memref.dim %[[C]], %[[C1]] : tensor // CHECK: scf.for %[[I:[0-9a-z]*]] -// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 
1] : tensor to tensor<2x?xf32> +// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]] +// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor to tensor +// CHECK: %[[sizeC0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dC0]]] // CHECK-NEXT: scf.for %[[J:[0-9a-z]*]] // CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]] // CHECK-DAG: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor to tensor<4x3xf32> // CHECK-DAG: %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor to tensor<2x3xf32> // // subtensors of the producing matmul. -// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor to tensor -// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor to tensor<2x4xf32> -// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]] +// CHECK: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor to tensor +// CHECK: %[[sizeC1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dC1]]] +// CHECK: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [%[[sizeC0]], %[[sizeC1]]] [1, 1] : tensor to tensor +// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor, tensor) outs(%[[stC]] : tensor) -> tensor +// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor to tensor +// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]] + +// ----- + +func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> { + %c112 = constant 112 : index + %c32 = constant 32 : index + %c16 = constant 16 : index + %c8 = constant 8 : index + %c4 = constant 4 : index + %c0 = constant 0 : index + %cst = constant 0.0 : f32 + + %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %fill = linalg.fill(%init, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32> + + %conv = linalg.conv_2d_input_nhwc_filter_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x225x225x32xf32>, tensor<3x3x3x32xf32>) + outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> + + %for0 = scf.for %iv0 = %c0 to %c112 step %c8 iter_args(%arg0 = %fill) -> tensor<1x112x112x32xf32> { + %for1 = scf.for %iv1 = %c0 to %c112 step %c16 iter_args(%arg1 = %arg0) -> tensor<1x112x112x32xf32> { + %for2 = scf.for %iv2 = %c0 to %c32 step %c4 iter_args(%arg2 = %arg1) -> tensor<1x112x112x32xf32> { + %0 = subtensor %conv[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> + %1 = subtensor %elementwise[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> + %2 = subtensor %arg2[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> + %add = linalg.generic + { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + 
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } + ins(%0, %1 : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>) outs(%2 : tensor<1x8x16x4xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %result = addf %arg3, %arg4 : f32 + linalg.yield %result : f32 + } -> tensor<1x8x16x4xf32> + + %insert = subtensor_insert %add into %arg2[0, %iv0, %iv1, %iv2] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x8x16x4xf32> into tensor<1x112x112x32xf32> + scf.yield %insert : tensor<1x112x112x32xf32> + } + scf.yield %for2 : tensor<1x112x112x32xf32> + } + scf.yield %for1 : tensor<1x112x112x32xf32> + } + return %for0 : tensor<1x112x112x32xf32> +} + +// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK: func @conv_tensors_static +// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x32xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>) + +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> +// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32> + +// CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]]) +// CHECK-NEXT: %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]]) +// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG1:.+]] = %[[ARG0]]) +// CHECK-NEXT: %[[OFFSET_W:.+]] = affine.apply #[[MAP0]](%[[IV1]]) +// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %arg0[0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 32] [1, 1, 1, 1] : tensor<1x225x225x32xf32> to tensor<1x17x33x32xf32> +// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG2:.+]] = %[[ARG1]]) +// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> +// CHECK-NEXT: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> +// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV2]]] [3, 3, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x4xf32> +// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32> +// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x32xf32>, tensor<3x3x3x4xf32>) +// CHECK-SAME: outs(%[[ST_FILL]] : tensor<1x8x16x4xf32>) +// CHECK-NEXT: %[[ADD:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>) +// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<1x8x16x4xf32>) +// CHECK: subtensor_insert %[[ADD]] into %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] + +// ----- + +#bound4_map = affine_map<(d0)[s0] -> (4, -d0 + s0)> +#bound8_map = affine_map<(d0)[s0] -> (8, -d0 + s0)> +#bound16_map = affine_map<(d0)[s0] -> (16, -d0 + s0)> + +func @conv_tensors_dynamic(%input: tensor, %filter: tensor, %elementwise: tensor) -> tensor { + %cst = constant 0.0 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c8 = constant 8 : index + %c16 = constant 16 : index + + %n = memref.dim %elementwise, %c0 : tensor + %oh = 
memref.dim %elementwise, %c1 : tensor + %ow = memref.dim %elementwise, %c2 : tensor + %oc = memref.dim %elementwise, %c3 : tensor + + %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor + %fill = linalg.fill(%init, %cst) : tensor, f32 -> tensor + + %conv = linalg.conv_2d_input_nhwc_filter_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor, tensor) + outs(%fill : tensor) -> tensor + + %for0 = scf.for %iv0 = %c0 to %oh step %c8 iter_args(%arg0 = %fill) -> tensor { + %for1 = scf.for %iv1 = %c0 to %ow step %c16 iter_args(%arg1 = %arg0) -> tensor { + %for2 = scf.for %iv2 = %c0 to %oc step %c4 iter_args(%arg2 = %arg1) -> tensor { + %for3 = scf.for %iv3 = %c0 to %oc step %c2 iter_args(%arg3 = %arg2) -> tensor { + %n_size = affine.min #bound8_map(%iv0)[%n] + %oh_size = affine.min #bound16_map(%iv1)[%oh] + %ow_size = affine.min #bound4_map(%iv2)[%ow] + %oc_size = affine.min #bound4_map(%iv2)[%oc] + %0 = subtensor %conv[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor to tensor + %1 = subtensor %elementwise[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor to tensor + %2 = subtensor %arg3[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor to tensor + %add = linalg.generic + { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } + ins(%0, %1 : tensor, tensor) outs(%2 : tensor) { + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + %result = addf %arg4, %arg5 : f32 + linalg.yield %result : f32 + } -> tensor + + %insert = subtensor_insert %add into %arg3[%iv0, %iv1, %iv2, %iv3] [%n_size, %oh_size, %ow_size, %oc_size] [1, 1, 1, 1] : tensor into tensor + scf.yield %insert : tensor + } + scf.yield %for3 : tensor + } + scf.yield %for2 : tensor + } + scf.yield %for1 : tensor + } + return %for0 : tensor +} + +// ----- + +// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)> +// CHECK: #[[BOUND_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)> +// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> + +// CHECK: func @conv_tensors_dynamic +// CHECK-SAME: (%[[INPUT]]: tensor, %[[FILTER]]: tensor, %[[ELEM]]: tensor) + +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[C3:.+]] = constant 3 : index + +// CHECK-DAG: %[[ELEM_N:.+]] = memref.dim %[[ELEM]], %[[C0]] : tensor +// CHECK-DAG: %[[ELEM_OH:.+]] = memref.dim %[[ELEM]], %[[C1]] : tensor +// CHECK-DAG: %[[ELEM_OW:.+]] = memref.dim %[[ELEM]], %[[C2]] : tensor +// CHECK-DAG: %[[ELEM_OC:.+]] = memref.dim %[[ELEM]], %[[C3]] : tensor + +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor, f32 -> tensor + +// CHECK-DAG: %[[FILTER_H:.+]] = memref.dim %[[FILTER]], %[[C0]] : tensor +// CHECK-DAG: %[[FILTER_W:.+]] = memref.dim %[[FILTER]], %[[C1]] : tensor +// CHECK-DAG: %[[INPUT_N:.+]] = memref.dim %[[INPUT]], %[[C0]] : tensor +// CHECK-DAG: 
%[[INPUT_H:.+]] = memref.dim %[[INPUT]], %[[C1]] : tensor +// CHECK-DAG: %[[INPUT_W:.+]] = memref.dim %[[INPUT]], %[[C2]] : tensor +// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor +// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor +// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor + +// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_OH]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]]) +// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]] +// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[INPUT_N]]] +// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[ELEM_N]]] +// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OW]] +// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]] +// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]]) +// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]] +// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]]] +// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OC]] +// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OC]]] +// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]]) +// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]] +// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0] +// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]] +// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]]] +// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]] +// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] +// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] +// CHECK-NEXT: %[[ST_ARG:.+]] = subtensor %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] +// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[FILTER_OC]]] +// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]] +// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[ELEM_OC]]] +// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] +// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]] +// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor, tensor) +// CHECK-SAME: outs(%[[ST_FILL]] : tensor) -> tensor +// CHECK-NEXT: %[[ST_ADD:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor, tensor) +// CHECK-SAME: outs(%[[ST_ARG]] : tensor) +// CHECK: subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] +// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] diff --git 
a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -179,6 +179,10 @@
 namespace {
 struct TestLinalgGreedyFusion
     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
   void runOnFunction() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns =
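
For readers following the test updates above: every new CHECK line has the same shape, where the fused producer's subview sizes are re-clamped against the producer's own operand dimensions (min(tile, dim - iv)) instead of reusing the consumer's tile size directly. Below is a minimal MLIR sketch of that pattern, not taken from the patch; the function name, operands, and the fill producer are illustrative assumptions, and the bound maps mirror the #[[MAP2]]/#[[MAP4]]-style maps checked in the tests.

#tile_map  = affine_map<(d0)[s0] -> (16, -d0 + s0)>
#clamp_map = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
#strided2d = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

func @fused_fill_sketch(%out: memref<?x?xf32>, %val: f32) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c16 = constant 16 : index
  %m = memref.dim %out, %c0 : memref<?x?xf32>
  %n = memref.dim %out, %c1 : memref<?x?xf32>
  scf.for %iv = %c0 to %m step %c16 {
    // Tile size picked by the tiled consumer for this iteration.
    %tile_m = affine.min #tile_map(%iv)[%m]
    // The fused producer re-derives its subview size by clamping the consumer
    // tile against the dimension of its own operand: min(%tile_m, %m_2 - %iv).
    %m_2 = memref.dim %out, %c0 : memref<?x?xf32>
    %tile_m_2 = affine.min #clamp_map(%iv, %tile_m)[%m_2]
    %sv = memref.subview %out[%iv, 0] [%tile_m_2, %n] [1, 1]
        : memref<?x?xf32> to memref<?x?xf32, #strided2d>
    linalg.fill(%sv, %val) : memref<?x?xf32, #strided2d>, f32
  }
  return
}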