diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -114,4 +114,11 @@ "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; def PromoteSubviewsLinalgOp : NativeCodeCall< "promoteSubviewsLinalgOp($_builder, op)">; +
+// Promotes the subviews feeding the operands at indices `operands` and, when
+// `marker` is non-empty, sets it as the transform marker on the new op.
+class PromoteSelectedSubviewsLinalgOp operands, string marker=""> : + NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" # + StrJoinInt.result # "}, \"" # marker # "\")">; + #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -121,6 +121,13 @@ SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
+/// Promotes the subview-defined operands of `op` whose buffer indices are
+/// listed in `operandIndicesToPromote`. If `linalgMarker` is non-empty it is
+/// set as the transform marker attribute on the newly created op. At least one
+/// selected operand must be defined by a subview; otherwise llvm_unreachable is hit.
+SmallVector promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker = ""); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -338,6 +338,23 @@ assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); + LinalgOp linOp = cast(op); + SmallVector toPromote; + for (int64_t i = 0; i < linOp.getNumInputsAndOutputBuffers(); ++i) + toPromote.push_back(i); + return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); +} + +SmallVector +mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter
&rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " + << *op << ":\n"); + + assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && + "DRR failure case must be a precondition"); + if (auto convOp = dyn_cast(op)) { // TODO(ntv): add a level of indirection to linalg.generic. if (convOp.padding()) @@ -348,11 +365,17 @@ assert(linOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputBuffers()) - if (auto sv = dyn_cast_or_null(it.getDefiningOp())) + for (int64_t index : operandIndicesToPromote) + if (auto sv = + dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) subViews.insert(sv); + if (!subViews.empty()) { - promoteSubViewOperands(rewriter, linOp, subViews); + auto newOp = + promoteSubViewOperands(rewriter, linOp, subViews); + if (!linalgMarker.empty()) + newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); return {}; } llvm_unreachable("DRR failure case must be a precondition"); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -395,3 +395,45 @@ // CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref, memref // CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref, memref // CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref, memref, memref + +func @promote_first_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = constant 2000 : index + %c3000 = constant 3000 : index + %c4000 = constant 4000 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg0, 0 : memref + %1 = dim %arg0, 1 : memref + %2 = dim %arg1, 1 : memref + loop.for %arg3 = %c0 to %0 step %c2000 { + loop.for %arg4 = %c0 to %2 step %c3000 { + loop.for
%arg5 = %c0 to %1 step %c4000 { + %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} : + memref, + memref, + memref + } + } + } + return +} +// CHECK-LABEL: func @promote_first_subview_matmul +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref +// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref to memref +// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref, !linalg.range, !linalg.range, memref +// CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref, memref +// CHECK : linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref, memref, memref diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -149,4 +149,14 @@ HasLinalgTransformMarker<"_promote_views_">]>> )]>;
+// Promote only operand 0 (the first matmul input): @promote_first_subview_matmul
+// expects the other two operands to remain unpromoted subviews.
+def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">), + [(Constraint, + HasLinalgTransformMarker<"_promote_first_view_">]>> + )]>; + #endif // 
TEST_LINALG_TRANSFORMS_PATTERNS