diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -114,4 +114,9 @@ "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; def PromoteSubviewsLinalgOp : NativeCodeCall< "promoteSubviewsLinalgOp($_builder, op)">; + +class PromoteSelectedSubviewsLinalgOp operands, string marker=""> : + NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" # + StrJoinInt.result # "}, \"" # marker # "\")">; + #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -121,6 +121,14 @@ SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op); +/// Similar to `promoteSubviewsLinalgOp` but only promotes the subviews that +/// feed the operands listed in `operandIndicesToPromote` (indices into the +/// op's input and output buffers); at least one of them must be a subview. +/// If `linalgMarker` is non-empty and the transformation succeeds, sets the +/// attribute `kLinalgTransformMarker` of the new op to `linalgMarker`.
+SmallVector promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker = ""); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -338,6 +338,24 @@ assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); + LinalgOp linOp = cast(op); + SmallVector toPromote; + int64_t nBuffers = linOp.getNumInputsAndOutputBuffers(); + toPromote.reserve(nBuffers); + for (int64_t i = 0; i < nBuffers; ++i) + toPromote.push_back(i); + return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); +} + +SmallVector mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, + ArrayRef operandIndicesToPromote, StringRef linalgMarker) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " + << *op << ":\n"); + + assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && + "DRR failure case must be a precondition"); + if (auto convOp = dyn_cast(op)) { // TODO(ntv): add a level of indirection to linalg.generic.
if (convOp.padding()) @@ -348,11 +366,16 @@ assert(linOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputBuffers()) - if (auto sv = dyn_cast_or_null(it.getDefiningOp())) + for (int64_t index : operandIndicesToPromote) + if (auto sv = + dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) subViews.insert(sv); + if (!subViews.empty()) { - promoteSubViewOperands(rewriter, linOp, subViews); + auto newOp = promoteSubViewOperands(rewriter, linOp, subViews); + if (!linalgMarker.empty()) + newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); return {}; } llvm_unreachable("DRR failure case must be a precondition"); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -395,3 +395,50 @@ // CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref, memref // CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref, memref // CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref, memref, memref + +func @promote_first_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = constant 2000 : index + %c3000 = constant 3000 : index + %c4000 = constant 4000 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg0, 0 : memref + %1 = dim %arg0, 1 : memref + %2 = dim %arg1, 1 : memref + loop.for %arg3 = %c0 to %0 step %c2000 { + loop.for %arg4 = %c0 to %2 step %c3000 { + loop.for %arg5 = %c0 to %1 step %c4000 { + %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} : + memref, 
memref, + memref + } + } + } + return +} +// CHECK-LABEL: func @promote_first_subview_matmul +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c2000 { +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c3000 { +// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c4000 { +// CHECK: %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[a0:.*]] = alloc({{%.*}}) : memref +// CHECK: %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref to memref +// CHECK: %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref +// CHECK-NOT: {{%.*}} = alloc( +// CHECK-NOT: {{%.*}} = std.view +// CHECK-NOT: {{%.*}} = subview +// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref +// CHECK-NOT: linalg.copy(%[[s1]], {{%.*}}) +// CHECK-NOT: linalg.copy(%[[s2]], {{%.*}}) +// CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref, memref, memref diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -149,4 +149,12 @@ HasLinalgTransformMarker<"_promote_views_">]>> )]>; +def : 
Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">), + [(Constraint, + HasLinalgTransformMarker<"_promote_first_view_">]>> + )]>; + #endif // TEST_LINALG_TRANSFORMS_PATTERNS