diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -97,6 +97,18 @@ operandsToPromote->insert(operands.begin(), operands.end()); return *this; } + // If true the full view should be used for the promoted buffer. If false, use + // the partial view. + Optional> useFullTileBuffers = None; + LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef uses) { + useFullTileBuffers = SmallVector(uses.begin(), uses.end()); + return *this; + } + /// Use the full tile for all operands. + LinalgPromotionOptions &setUseFullTileBuffers(bool use) { + useFullTileBuffers = SmallVector(1, use); + return *this; + } /// Allow the use of dynamicaly-sized buffers. bool dynamicBuffers = false; LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) { diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h @@ -16,6 +16,9 @@ using vector_broadcast = ValueBuilder; using vector_contract = ValueBuilder; +using vector_insert = ValueBuilder; +using vector_fma = ValueBuilder; +using vector_extract = ValueBuilder; using vector_matmul = ValueBuilder; using vector_print = OperationBuilder; using vector_transfer_read = ValueBuilder; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -56,6 +56,8 @@ const LinalgPromotionOptions &options); /// SubViews to promote. SetVector subViews; + /// True if the full view should be used for the promoted buffer. + DenseMap useFullTileBuffers; /// Allow the use of dynamicaly-sized buffers. bool dynamicBuffers; /// Alignment of promoted buffer. @@ -65,20 +67,31 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( LinalgOp linalgOp, const LinalgPromotionOptions &options) - : subViews(), dynamicBuffers(options.dynamicBuffers), + : subViews(), useFullTileBuffers(), dynamicBuffers(options.dynamicBuffers), alignment(options.alignment) { + unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); + SmallVector vUseFullTileBuffers(nBuffers, false); + if (options.useFullTileBuffers.hasValue()) { + auto v = options.useFullTileBuffers.getValue(); + for (unsigned i = 0; i < nBuffers && i < v.size(); ++i) + vUseFullTileBuffers[i] = v[i]; + } + if (options.operandsToPromote.hasValue()) { for (unsigned idx : options.operandsToPromote.getValue()) { auto *op = linalgOp.getBuffer(idx).getDefiningOp(); - if (auto sv = dyn_cast_or_null(op)) + if (auto sv = dyn_cast_or_null(op)) { subViews.insert(sv); + this->useFullTileBuffers[sv] = vUseFullTileBuffers[idx]; + } } } else { - unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); for (unsigned idx = 0; idx < nBuffers; ++idx) { auto *op = linalgOp.getBuffer(idx).getDefiningOp(); - if (auto sv = dyn_cast_or_null(op)) + if (auto sv = dyn_cast_or_null(op)) { subViews.insert(sv); + this->useFullTileBuffers[sv] = vUseFullTileBuffers[idx]; + } } } } @@ -244,7 +257,10 @@ unsigned promotedIdx = 0; for (auto view : op.getInputsAndOutputBuffers()) { if (options.subViews.count(view) != 0) { - opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); + if (options.useFullTileBuffers[view]) + opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); + else + opViews.push_back(promotedBufferAndViews[promotedIdx].partialLocalView); writebackViews.emplace_back(std::make_pair( view, promotedBufferAndViews[promotedIdx].partialLocalView)); promotedIdx++; diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -63,7 +63,7 @@ // CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref // CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref // -// CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref, memref, memref +// CHECK: linalg.matmul(%[[partialA]], %[[partialB]], %[[partialC]]) : memref, memref, memref // // CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref, memref // @@ -128,7 +128,7 @@ // CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref, memref // CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref, memref // -// CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref, memref, memref +// CHECK: linalg.matmul(%[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]]) : memref, memref, memref // // CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref, memref // @@ -193,7 +193,7 @@ // CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref, memref // CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref, memref // -// CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref, memref, memref +// CHECK: linalg.matmul(%[[partialA_i32]], %[[partialB_i32]], %[[partialC_i32]]) : memref, memref, memref // // CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref, memref // diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -132,13 +132,19 @@ // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.insert>( - ctx, LinalgPromotionOptions(), + ctx, LinalgPromotionOptions().setUseFullTileBuffers({true, true, true}), LinalgMarker({"_promote_views_"}, "_views_promoted_")); patterns.insert>( - ctx, LinalgPromotionOptions().setOperandsToPromote({0}), + ctx, + LinalgPromotionOptions().setOperandsToPromote({0}).setUseFullTileBuffers( + true), LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); patterns.insert>( - ctx, LinalgPromotionOptions().setOperandsToPromote({0}).setAlignment(32), + ctx, + LinalgPromotionOptions() + .setOperandsToPromote({0}) + .setUseFullTileBuffers(true) + .setAlignment(32), LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); applyPatternsAndFoldGreedily(funcOp, patterns); @@ -171,7 +177,8 @@ LinalgMarker({startMarker}, "L1"))); patternsVector.emplace_back(LinalgPromotionPattern( - context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC"))); + context, LinalgPromotionOptions().setUseFullTileBuffers(true), + LinalgMarker({"L1"}, "VEC"))); patternsVector.emplace_back( LinalgVectorizationPattern(context, LinalgMarker({"VEC"})));