diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2355,6 +2355,102 @@ } }; +/// Merge an affine min/max op into its consumer if the consumer is also an +/// affine min/max op of the same kind (min into min, max into max). +/// +/// This pattern requires that the producer's result be bound to a +/// dimension/symbol that is used as a standalone expression in the consumer +/// affine op's map. The producer op itself is not erased by this pattern. +/// +/// For example, a pattern like the following: +/// +/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1] +/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2] +/// +/// Can be turned into: +/// +/// %1 = affine.min affine_map< +/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] +template <typename T> +struct MergeAffineMinMaxOp : public OpRewritePattern<T> { + using OpRewritePattern<T>::OpRewritePattern; + + LogicalResult matchAndRewrite(T affineOp, + PatternRewriter &rewriter) const override { + AffineMap oldMap = affineOp.getAffineMap(); + ValueRange dimOperands = + affineOp.getMapOperands().take_front(oldMap.getNumDims()); + ValueRange symOperands = + affineOp.getMapOperands().take_back(oldMap.getNumSymbols()); + + auto newDimOperands = llvm::to_vector<8>(dimOperands); + auto newSymOperands = llvm::to_vector<8>(symOperands); + SmallVector<AffineExpr, 4> newExprs; + SmallVector<T, 4> producerOps; + + // Go over each expression to see whether it is a single dimension/symbol + // whose corresponding operand is the result of another affine min/max op + // of the same kind. If so, it can be merged into this affine op. 
+ for (AffineExpr expr : oldMap.getResults()) { + if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { + Value symValue = symOperands[symExpr.getPosition()]; + if (auto producerOp = symValue.getDefiningOp<T>()) { + producerOps.push_back(producerOp); + continue; + } + } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { + Value dimValue = dimOperands[dimExpr.getPosition()]; + if (auto producerOp = dimValue.getDefiningOp<T>()) { + producerOps.push_back(producerOp); + continue; + } + } + // In the cases above the expression is dropped here and replaced later by + // the producer affine min/max's own expressions. Otherwise we need to + // keep the existing expression. + newExprs.push_back(expr); + } + + if (producerOps.empty()) + return failure(); + + unsigned numUsedDims = oldMap.getNumDims(); + unsigned numUsedSyms = oldMap.getNumSymbols(); + + // Now go over all producer affine ops and merge their expressions. + for (T producerOp : producerOps) { + AffineMap producerMap = producerOp.getAffineMap(); + unsigned numProducerDims = producerMap.getNumDims(); + unsigned numProducerSyms = producerMap.getNumSymbols(); + + // Append the producer's dim/symbol operands after the existing ones. + ValueRange dimValues = + producerOp.getMapOperands().take_front(numProducerDims); + ValueRange symValues = + producerOp.getMapOperands().take_back(numProducerSyms); + newDimOperands.append(dimValues.begin(), dimValues.end()); + newSymOperands.append(symValues.begin(), symValues.end()); + + // Shift the producer's expressions past the positions already in use. 
+ for (AffineExpr expr : producerMap.getResults()) { + newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims) + .shiftSymbols(numProducerSyms, numUsedSyms)); + } + + numUsedDims += numProducerDims; + numUsedSyms += numProducerSyms; + } + + auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs, + rewriter.getContext()); + auto newOperands = + llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands)); + rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // AffineMinOp //===----------------------------------------------------------------------===// @@ -2368,8 +2464,10 @@ void AffineMinOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<DeduplicateAffineMinMaxExpressions<AffineMinOp>, - SimplifyAffineOp<AffineMinOp>>(context); + patterns + .insert<DeduplicateAffineMinMaxExpressions<AffineMinOp>, + MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>( + context); } //===----------------------------------------------------------------------===// @@ -2385,8 +2483,10 @@ void AffineMaxOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert<DeduplicateAffineMinMaxExpressions<AffineMaxOp>, - SimplifyAffineOp<AffineMaxOp>>(context); + patterns + .insert<DeduplicateAffineMinMaxExpressions<AffineMaxOp>, + MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -718,3 +718,169 @@ %0 = affine.max affine_map<()[s0, s1] -> (s0 + s1, s0 * s1, s1 + s0, s0 * s1)> ()[%i0, %i1] return %0: index } + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 3, 16, -s1 + s2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (-s1 + 5, 16, -s0 + s2)> + +// CHECK: func @merge_affine_min_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: 
index, %[[I3:.+]]: index) +func @merge_affine_min_ops(%i0: index, %i1: index, %i2: index, %i3: index) -> (index, index) { + %0 = affine.min affine_map<(d0)[s0] -> (16, d0 - s0)> (%i0)[%i1] + + // CHECK: affine.min #[[MAP0]]()[%[[I2]], %[[I1]], %[[I0]]] + %1 = affine.min affine_map<(d0)[s0] -> (3 * s0, d0)> (%0)[%i2] // Use as dim + // CHECK: affine.min #[[MAP1]]()[%[[I1]], %[[I3]], %[[I0]]] + %2 = affine.min affine_map<(d0)[s0] -> (s0, 5 - d0)> (%i3)[%0] // Use as symbol + + return %1, %2: index, index +} + +// ----- + +// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + 7, s1 + 16, s1 * 8, s2 + 8, s2 * 4)> + +// CHECK: func @merge_multiple_affine_min_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index) +func @merge_multiple_affine_min_ops(%i0: index, %i1: index, %i2: index) -> index { + %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.min affine_map<()[s0] -> (s0 + 8, s0 * 4)> ()[%i1] + // CHECK: affine.min #[[MAP]]()[%[[I2]], %[[I0]], %[[I1]]] + %2 = affine.min affine_map<()[s0, s1, s2] -> (s0, 7 + s1, s2)> ()[%0, %i2, %1] + return %2: index +} + +// ----- + +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2, s1 + 16, s1 * 8)> + +// CHECK: func @merge_multiple_uses_of_affine_min_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index) +func @merge_multiple_uses_of_affine_min_ops(%i0: index, %i1: index) -> index { + %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + // CHECK: affine.min #[[MAP]]()[%[[I1]], %[[I0]]] + %2 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2 * 2)> ()[%0, %0, %i1] + return %2: index +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 16, s0 * 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 + 1, s1 * 2, s2 + 16, s2 * 8)> + +// CHECK: func @merge_mixed_uses_of_affine_min_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index) +func @merge_mixed_uses_of_affine_min_ops(%i0: index, %i1: index) -> index { + 
// CHECK: %[[AFFINE:.+]] = affine.min #[[MAP0]]()[%[[I0]]] + %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + // %0 is bound to a symbol that is both a standalone expression and a part + // of other expressions. + // CHECK: affine.min #[[MAP1]]()[%[[AFFINE]], %[[I1]], %[[I0]]] + %2 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1 + 1, s2 * 2)> ()[%0, %0, %i1] + return %2: index +} + +// ----- + +// CHECK-LABEL: func @dont_merge_affine_min_if_not_single_dim +func @dont_merge_affine_min_if_not_single_dim(%i0: index, %i1: index, %i2: index) -> index { + // CHECK-COUNT-2: affine.min + %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, 7 + d0)> (%0)[%i2] + return %1: index +} + + +// ----- + +// CHECK-LABEL: func @dont_merge_affine_min_if_not_single_sym +func @dont_merge_affine_min_if_not_single_sym(%i0: index, %i1: index, %i2: index) -> index { + // CHECK-COUNT-2: affine.min + %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.min affine_map<()[s0, s1] -> (s0 + 4, 7 + s1)> ()[%0, %i2] + return %1: index +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 3, 16, -s1 + s2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (-s1 + 5, 16, -s0 + s2)> + +// CHECK: func @merge_affine_max_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index) +func @merge_affine_max_ops(%i0: index, %i1: index, %i2: index, %i3: index) -> (index, index) { + %0 = affine.max affine_map<(d0)[s0] -> (16, d0 - s0)> (%i0)[%i1] + + // CHECK: affine.max #[[MAP0]]()[%[[I2]], %[[I1]], %[[I0]]] + %1 = affine.max affine_map<(d0)[s0] -> (3 * s0, d0)> (%0)[%i2] // Use as dim + // CHECK: affine.max #[[MAP1]]()[%[[I1]], %[[I3]], %[[I0]]] + %2 = affine.max affine_map<(d0)[s0] -> (s0, 5 - d0)> (%i3)[%0] // Use as symbol + + return %1, %2: index, index +} + +// ----- + +// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + 
7, s1 + 16, s1 * 8, s2 + 8, s2 * 4)> + +// CHECK: func @merge_multiple_affine_max_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index) +func @merge_multiple_affine_max_ops(%i0: index, %i1: index, %i2: index) -> index { + %0 = affine.max affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.max affine_map<()[s0] -> (s0 + 8, s0 * 4)> ()[%i1] + // CHECK: affine.max #[[MAP]]()[%[[I2]], %[[I0]], %[[I1]]] + %2 = affine.max affine_map<()[s0, s1, s2] -> (s0, 7 + s1, s2)> ()[%0, %i2, %1] + return %2: index +} + +// ----- + +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2, s1 + 16, s1 * 8)> + +// CHECK: func @merge_multiple_uses_of_affine_max_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index) +func @merge_multiple_uses_of_affine_max_ops(%i0: index, %i1: index) -> index { + %0 = affine.max affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + // CHECK: affine.max #[[MAP]]()[%[[I1]], %[[I0]]] + %2 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2 * 2)> ()[%0, %0, %i1] + return %2: index +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 16, s0 * 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 + 1, s1 * 2, s2 + 16, s2 * 8)> + +// CHECK: func @merge_mixed_uses_of_affine_max_ops +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index) +func @merge_mixed_uses_of_affine_max_ops(%i0: index, %i1: index) -> index { + // CHECK: %[[AFFINE:.+]] = affine.max #[[MAP0]]()[%[[I0]]] + %0 = affine.max affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + // %0 is bound to a symbol that is both a standalone expression and a part + // of other expressions. 
+ // CHECK: affine.max #[[MAP1]]()[%[[AFFINE]], %[[I1]], %[[I0]]] + %2 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1 + 1, s2 * 2)> ()[%0, %0, %i1] + return %2: index +} + +// ----- + +// CHECK-LABEL: func @dont_merge_affine_max_if_not_single_dim +func @dont_merge_affine_max_if_not_single_dim(%i0: index, %i1: index, %i2: index) -> index { + // CHECK-COUNT-2: affine.max + %0 = affine.max affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.max affine_map<(d0)[s0] -> (s0 + 4, 7 + d0)> (%0)[%i2] + return %1: index +} + + +// ----- + +// CHECK-LABEL: func @dont_merge_affine_max_if_not_single_sym +func @dont_merge_affine_max_if_not_single_sym(%i0: index, %i1: index, %i2: index) -> index { + // CHECK-COUNT-2: affine.max + %0 = affine.max affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%i0] + %1 = affine.max affine_map<()[s0, s1] -> (s0 + 4, 7 + s1)> ()[%0, %i2] + return %1: index +} diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -16,7 +16,8 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 32, -d0 + s1)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 64, -d0 + s1)> // CHECK: func @basic_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -48,8 +49,8 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] -// CHECK: 
%[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV1]], %[[TILE_N]])[%[[N_2]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]] +// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]] // CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) @@ -89,7 +90,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 64, -d0 + s1)> // CHECK: func @rhs_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -117,10 +118,10 @@ // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] // CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]] // CHECK: %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N_3]]] +// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[N_3]], %[[N]]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]] // CHECK-SAME: [%[[K_2]], %[[TILE_N_3]]] -// CHECK: %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N]]] +// CHECK: %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[N]], %[[N]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]] // CHECK-SAME: [%[[K]], %[[TILE_N_4]]] // CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) @@ -171,7 +172,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = 
affine_map<(d0)[s0, s1] -> (-d0 + s0, 32, -d0 + s1)> // CHECK: func @two_operand_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -199,15 +200,15 @@ // CHECK: %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]] // CHECK: %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_3]], %[[N]]] // CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] -// CHECK: %[[K_2:.+]] = memref.dim %[[ARG0]], %[[C1]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]] +// CHECK: %[[K_3:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_4]], %[[K_2]]] -// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK-SAME: [%[[TILE_M_4]], %[[K_3]]] +// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M]], %[[M]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_5]], %[[K]]] // CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) @@ -258,6 +259,7 @@ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 32, -d0 + s1)> // CHECK: func @matmul_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref @@ -284,11 +286,11 @@ // CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] // CHECK: %[[M_3:.+]] = 
memref.dim %[[ARG0]], %[[C0]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]] // CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_3]], %[[K1]]] -// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M]], %[[M]]] // CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_4]], %[[K2]]] // CHECK: linalg.matmul diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -84,7 +84,9 @@ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> + // CHECK: func @sequence_of_matmul // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref @@ -110,17 +112,18 @@ // CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]] // CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M_2]], %[[M]]] // CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_3]], %[[N3]]] -// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP3]](%[[IV0]])[%[[M]]] // CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] // 
CHECK-SAME: [%[[TILE_M_4]], %[[N1]]] // CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_4]], %[[N2]]] +// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[M]]] // CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_4]], %[[N0]]] +// CHECK-SAME: [%[[TILE_M_5]], %[[N0]]] // CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}}) // CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]] // CHECK-SAME: : memref, memref) @@ -207,9 +210,8 @@ } } -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)> // CHECK: func @tensor_matmul_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor @@ -224,23 +226,22 @@ // CHECK: %[[M:.+]] = memref.dim %[[ARG0]], %c0 : tensor // CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = // CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] // CHECK: %[[M_1:.+]] = memref.dim %[[ARG8]], %[[C0]] -// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP1]](%[[M_1]], %[[IV0]]) +// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP0]](%[[M_1]], %[[IV0]]) // CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]] // CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_1]], %[[N3]]] // CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]] -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M_2]], %[[M]]] // CHECK: %[[N2:.+]] = memref.dim %[[ARG4]], %[[C1]] // CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N2]]] -// 
CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[M]]] // CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_3]], %[[N0]]] // CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M_3]], %[[M]]] // CHECK: %[[N1:.+]] = memref.dim %[[ARG2]], %[[C1]] // CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_4]], %[[N1]]] diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -12,12 +12,11 @@ return %1 : tensor } } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 32, -d0 + s1)> // CHECK: func @matmul_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor @@ -35,18 +34,17 @@ // CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]] // CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] // CHECK: %[[M_2:.+]] = memref.dim %[[ARG6]], %[[C0]] // CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[M_2]], %[[IV0]]) // CHECK: %[[N3:.+]] = memref.dim %[[ARG6]], %[[C1]] // CHECK: %[[ST_ARG6:.+]] = 
subtensor %[[ARG6]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M]]] +// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]] // CHECK: %[[N1:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_3]], %[[N1]]] // CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M_3]]] +// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M_3]], %[[M]]] // CHECK: %[[N2_2:.+]] = memref.dim %[[ARG2]], %[[C1]] // CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_4]], %[[N2_2]]] diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -252,10 +252,11 @@ } return %E : memref } -// CHECK: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> -// CHECK: #[[BOUND_ID_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> -// CHECK: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK: func @f5 + +// CHECK-DAG: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// CHECK-DAG: #[[BOUND_2_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 2, -d0 + s1)> +// CHECK-DAG: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> +// CHECK: func @f5 // HECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index @@ -269,8 +270,7 @@ // CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]] // CHECK: %[[BOUND_2_D0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[D_0]]] // CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0] -// Note that %[[BOUND_ID_C0]] is essentially %[[BOUND_2_C0]]. 
-// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_ID_MAP]](%[[I]], %[[BOUND_2_C0]])[%[[C_0]]] +// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_2_MAP_2]](%[[I]])[%[[C_0]], %[[C_0]]] // CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]] // CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} { // CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]] diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -210,11 +210,16 @@ // ----- // CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)> -// CHECK: #[[BOUND_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)> +// CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)> +// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)> // CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)> // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)> +// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> +// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> +// CHECK: #[[BOUND4_MAP_3:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 4, -d1 + s1)> +// CHECK: #[[BOUND4_MAP_4:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 4, -d1 + s0)> // CHECK: func @conv_tensors_dynamic // CHECK-SAME: (%[[INPUT]]: tensor, %[[FILTER]]: tensor, %[[ELEM]]: tensor) @@ -243,13 +248,13 @@ // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_OH]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]]) // CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]] -// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], 
%[[SIZE_ELEM_N]])[%[[INPUT_N]]] -// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[ELEM_N]]] +// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]] +// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]] // CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OW]] // CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]] // CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]]) // CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]] -// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]]] +// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]] // CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OC]] // CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]] // CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OC]]] @@ -257,16 +262,16 @@ // CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]] // CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0] // CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]] -// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]]] +// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]] // CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]] // CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] // CHECK-NEXT: %[[ST_ARG:.+]] = subtensor 
%[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]] -// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[FILTER_OC]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND4_MAP_3]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]] // CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]] // CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]] -// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[ELEM_OC]]] +// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND4_MAP_4]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]] // CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]] // CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf