diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SmallBitVector.h"

 namespace mlir {
 namespace linalg {
@@ -97,6 +98,28 @@
     operandsToPromote->insert(operands.begin(), operands.end());
     return *this;
   }
+  /// If the ith element of `useFullTiles` is true, the full view should be
+  /// used for the promoted buffer of the ith operand in `operandsToPromote`.
+  /// Otherwise the partial view will be used.
+  /// The decision defaults to `useFullTileBuffersDefault` when
+  /// `useFullTileBuffers` is None and for operands missing from
+  /// `useFullTileBuffers`.
+  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
+  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
+    unsigned size = useFullTiles.size();
+    llvm::SmallBitVector tmp(size, false);
+    for (unsigned i = 0; i < size; ++i)
+      tmp[i] = useFullTiles[i];
+    useFullTileBuffers = tmp;
+    return *this;
+  }
+  /// If true, all operands unspecified by `useFullTileBuffers` will use the
+  /// full view; otherwise the partial view.
+  bool useFullTileBuffersDefault = false;
+  LinalgPromotionOptions &useFullTileBuffersByDefault() {
+    useFullTileBuffersDefault = true;
+    return *this;
+  }
   /// Allow the use of dynamicaly-sized buffers.
   bool dynamicBuffers = false;
   LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
@@ -16,6 +16,9 @@
 using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
 using vector_contract = ValueBuilder<vector::ContractionOp>;
+using vector_insert = ValueBuilder<vector::InsertOp>;
+using vector_fma = ValueBuilder<vector::FMAOp>;
+using vector_extract = ValueBuilder<vector::ExtractOp>;
 using vector_matmul = ValueBuilder<vector::MatmulOp>;
 using vector_print = OperationBuilder<vector::PrintOp>;
 using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
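A minimal usage sketch of the options added above, configured the same way as the "_promote_views_aligned_" pattern in the TestLinalgTransforms.cpp hunk further down: promote only operand 0, back it with the full tile view, and align the promoted buffer to 32 bytes. The setters are the ones introduced in this patch; the helper function name is illustrative only.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir::linalg;

// Promote only operand 0 and use the full tile buffer for it. Operands not
// covered by setUseFullTileBuffers fall back to useFullTileBuffersDefault,
// i.e. the partial view unless useFullTileBuffersByDefault() is called.
static LinalgPromotionOptions makeAlignedPromotionOptions() {
  return LinalgPromotionOptions()
      .setOperandsToPromote({0})
      .setUseFullTileBuffers({true})
      .setAlignment(32);
}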
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -56,6 +56,8 @@
       const LinalgPromotionOptions &options);
   /// SubViews to promote.
   SetVector<Value> subViews;
+  /// True if the full view should be used for the promoted buffer.
+  DenseMap<Value, bool> useFullTileBuffers;
   /// Allow the use of dynamicaly-sized buffers.
   bool dynamicBuffers;
   /// Alignment of promoted buffer.
@@ -65,20 +67,28 @@
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     LinalgOp linalgOp, const LinalgPromotionOptions &options)
-    : subViews(), dynamicBuffers(options.dynamicBuffers),
+    : subViews(), useFullTileBuffers(), dynamicBuffers(options.dynamicBuffers),
       alignment(options.alignment) {
+  unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
+  auto vUseFullTileBuffers =
+      options.useFullTileBuffers.getValueOr(llvm::SmallBitVector());
+  vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault);
+
   if (options.operandsToPromote.hasValue()) {
-    for (unsigned idx : options.operandsToPromote.getValue()) {
-      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+    for (auto it : llvm::enumerate(options.operandsToPromote.getValue())) {
+      auto *op = linalgOp.getBuffer(it.value()).getDefiningOp();
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
         subViews.insert(sv);
+        useFullTileBuffers[sv] = vUseFullTileBuffers[it.index()];
+      }
     }
   } else {
-    unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
     for (unsigned idx = 0; idx < nBuffers; ++idx) {
       auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
         subViews.insert(sv);
+        useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
+      }
     }
   }
 }
@@ -201,6 +211,9 @@
     auto info = promotionInfoMap.find(v);
     if (info == promotionInfoMap.end())
       continue;
+    // Only fill the buffer if the full local view is used.
+    if (!options.useFullTileBuffers[v])
+      continue;
     Value fillVal;
     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
@@ -244,7 +257,10 @@
   unsigned promotedIdx = 0;
   for (auto view : op.getInputsAndOutputBuffers()) {
     if (options.subViews.count(view) != 0) {
-      opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+      if (options.useFullTileBuffers[view])
+        opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+      else
+        opViews.push_back(promotedBufferAndViews[promotedIdx].partialLocalView);
       writebackViews.emplace_back(std::make_pair(
           view, promotedBufferAndViews[promotedIdx].partialLocalView));
       promotedIdx++;
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -56,14 +56,11 @@
 // DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref
 // CHECK: %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref to memref
-// CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref, f32
-// CHECK: linalg.fill(%[[fullB]], {{.*}}) : memref, f32
-// CHECK: linalg.fill(%[[fullC]], {{.*}}) : memref, f32
 // CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref, memref
 // CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref
 // CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref
 //
-// CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref, memref, memref
+// CHECK: linalg.matmul(%[[partialA]], %[[partialB]], %[[partialC]]) : memref, memref, memref
 //
 // CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref, memref
 //
@@ -121,14 +118,11 @@
 // DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref
 // CHECK: %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref to memref
-// CHECK: linalg.fill(%[[fullA_f64]], {{.*}}) : memref, f64
-// CHECK: linalg.fill(%[[fullB_f64]], {{.*}}) : memref, f64
-// CHECK: linalg.fill(%[[fullC_f64]], {{.*}}) : memref, f64
 // CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref, memref
 // CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref, memref
 // CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref, memref
 //
-// CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref, memref, memref
+// CHECK: linalg.matmul(%[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]]) : memref, memref, memref
 //
 // CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref, memref
 //
@@ -186,14 +180,11 @@
 // DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref
 // CHECK: %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref to memref
-// CHECK: linalg.fill(%[[fullA_i32]], {{.*}}) : memref, i32
-// CHECK: linalg.fill(%[[fullB_i32]], {{.*}}) : memref, i32
-// CHECK: linalg.fill(%[[fullC_i32]], {{.*}}) : memref, i32
 // CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref, memref
 // CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref, memref
 // CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref, memref
 //
-// CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref, memref, memref
+// CHECK: linalg.matmul(%[[partialA_i32]], %[[partialB_i32]], %[[partialC_i32]]) : memref, memref, memref
 //
 // CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref, memref
 //
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -132,13 +132,20 @@
   //===--------------------------------------------------------------------===//
   // Linalg subview operands promotion.
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgPromotionOptions(),
+      ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
       LinalgMarker({"_promote_views_"}, "_views_promoted_"));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgPromotionOptions().setOperandsToPromote({0}),
+      ctx,
+      LinalgPromotionOptions()
+          .setOperandsToPromote({0})
+          .useFullTileBuffersByDefault(),
       LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgPromotionOptions().setOperandsToPromote({0}).setAlignment(32),
+      ctx,
+      LinalgPromotionOptions()
+          .setOperandsToPromote({0})
+          .setUseFullTileBuffers({true})
+          .setAlignment(32),
       LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
   applyPatternsAndFoldGreedily(funcOp, patterns);
@@ -171,7 +178,8 @@
                                      LinalgMarker({startMarker}, "L1")));
   patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
-      context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
+      context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
+      LinalgMarker({"L1"}, "VEC")));
   patternsVector.emplace_back(
       LinalgVectorizationPattern(context, LinalgMarker({"VEC"})));
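For reference, a minimal sketch of how a downstream pass might register the promotion pattern with the new full-tile-view behavior, mirroring the "_promote_views_" registration in the TestLinalgTransforms.cpp hunk above. The helper name is illustrative, and anchoring the pattern on MatmulOp is an assumption taken from the matmul-based promote.mlir test; the option and marker calls themselves come from this patch.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

// Promote every subview operand of matmul and hand the full tile views to the
// op; the partial views are still used for the copy-in/copy-back, and only
// full-view buffers get the zero fill.
static void addFullViewPromotionPattern(MLIRContext *ctx,
                                        OwningRewritePatternList &patterns) {
  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
      LinalgMarker({"_promote_views_"}, "_views_promoted_"));
}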