diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -40,10 +40,6 @@ linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops); -std::unique_ptr> -createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca); -std::unique_ptr> createLinalgPromotionPass(); - std::unique_ptr> createLinalgInlineScalarOperandsPass(); @@ -101,14 +97,6 @@ const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyPromotePass. -std::unique_ptr> createLinalgStrategyPromotePass( - StringRef opName = "", - const linalg::LinalgPromotionOptions &opt = - linalg::LinalgPromotionOptions(), - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyGeneralizePass. std::unique_ptr> createLinalgStrategyGeneralizePass( StringRef opName = "", const linalg::LinalgTransformationFilter &filter = diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -112,18 +112,6 @@ ]; } -def LinalgPromotion : Pass<"linalg-promote-subviews", "func::FuncOp"> { - let summary = "Promote subview ops to local buffers"; - let constructor = "mlir::createLinalgPromotionPass()"; - let options = [ - Option<"dynamicBuffers", "test-promote-dynamic", "bool", - /*default=*/"false", "Test generation of dynamic promoted buffers">, - Option<"useAlloca", "test-use-alloca", "bool", - /*default=*/"false", "Test generation of alloca'ed buffers."> - ]; - let dependentDialects = ["linalg::LinalgDialect"]; -} - def LinalgTiling : Pass<"linalg-tile", "func::FuncOp"> { let summary = "Tile operations in the linalg dialect"; let constructor = "mlir::createLinalgTilingPass()"; @@ -222,19 +210,6 @@ ]; } -def 
LinalgStrategyPromotePass - : Pass<"linalg-strategy-promote-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based linalg promotion."; - let constructor = "mlir::createLinalgStrategyPromotePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - def LinalgStrategyGeneralizePass : Pass<"linalg-strategy-generalize-pass", "func::FuncOp"> { let summary = "Configurable pass to apply pattern-based generalization."; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -191,7 +191,6 @@ }]; } - def PadOp : Op { @@ -232,6 +231,46 @@ }]; } +def PromoteOp : Op { + let description = [{ + Promotes the specified operands of the target into a separate memory buffer. + + At this point, this transform does not allow customizing alloc/dealloc + functions nor the behavior on copy in/out operations. + + #### Return modes + + This operation applies to a single Linalg op that satisfies the + `promoteSubviewsPrecondition`, otherwise it fails. + + If the operation referred to by the `target` PDLOperation promotes + properly, the transform succeeds. + + When successful, the return handles point to the $target operation that was + modified in place. 
+ }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$operands_to_promote, + DefaultValuedAttr:$force_full_tiles, + UnitAttr:$use_full_tiles_by_default, + UnitAttr:$use_dynamic_buffers, + UnitAttr:$use_alloca, + OptionalAttr:$alignment); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def ScalarizeOp : Op { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -81,23 +81,6 @@ linalg::LinalgPaddingOptions options; }; -/// Represent one application of createLinalgStrategyPromotePass. -struct Promote : public Transformation { - Promote(StringRef name, linalg::LinalgPromotionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyPromotePass(opName, options, m)); - } - -private: - std::string opName; - linalg::LinalgPromotionOptions options; -}; - /// Represent one application of createLinalgStrategyGeneralizePass. struct Generalize : public Transformation { explicit Generalize(StringRef name, @@ -253,22 +236,6 @@ LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? pad(opName, std::move(options), std::move(f)) : *this; } - /// Append a pattern to add a level of promotion for `LinalgOpType` with - /// promotion `options`. 
- CodegenStrategy & - promote(StringRef opName, const linalg::LinalgPromotionOptions &options, - const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to add a level of promotion for - /// `LinalgOpType` with promotion `options`. - CodegenStrategy & - promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? promote(opName, std::move(options), std::move(f)) : *this; - } /// Append a pattern to generalize named operations. CodegenStrategy & generalize(StringRef opName, diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -992,54 +992,6 @@ LinalgTransformationFilter filter; }; -/// -/// Linalg promotion patterns. -/// -/// Apply the `promoteSubViews` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `promoteSubViews` for more details. -struct LinalgBasePromotionPattern : public RewritePattern { - /// Entry point to match any LinalgOp OpInterface. - /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgPromotionOptions options = LinalgPromotionOptions(), - PatternBenefit benefit = 1); - /// Entry point to match a specific Linalg op. 
- LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Promotion options. - LinalgPromotionOptions options; -}; - -template -struct LinalgPromotionPattern : public LinalgBasePromotionPattern { - /// SFINAE: This constructor can only trigger for concrete ops that have a - /// static `getOperationName` method. - template - LinalgPromotionPattern( - MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options, - f, benefit) {} - /// This constructor is available to anyone. - LinalgPromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(opName, context, options, f, benefit) {} -}; - /// /// Linalg peeling patterns. 
/// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -423,6 +423,37 @@ return success(); } +//===----------------------------------------------------------------------===// +// PromoteOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::PromoteOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { + auto promotionOptions = LinalgPromotionOptions() + .setOperandsToPromote(extractFromI64ArrayAttr(getOperandsToPromote())) + .setUseFullTileBuffersByDefault(getUseFullTilesByDefault()) + .setDynamicBuffers(getUseDynamicBuffers()) + .setUseAlloca(getUseAlloca()); + if (!getForceFullTiles().empty()) + promotionOptions = promotionOptions.setUseFullTileBuffers( + llvm::to_vector(getForceFullTiles().getAsValueRange())); + if (getAlignment().hasValue()) + promotionOptions = promotionOptions.setAlignment(*getAlignment()); + + if (failed(promoteSubviewsPrecondition(target, promotionOptions))) + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + SimpleRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + FailureOr res = promoteSubViews(rewriter, target, promotionOptions); + if (failed(res)) + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + results.push_back(target); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- 
a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -230,38 +230,6 @@ LinalgTransformationFilter filter; }; -/// Configurable pass to apply pattern-based linalg promotion. -struct LinalgStrategyPromotePass - : public LinalgStrategyPromotePassBase { - - LinalgStrategyPromotePass() = default; - - LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, - LinalgTransformationFilter filt) - : options(std::move(opt)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet promotionPattern(funcOp.getContext()); - if (!anchorOpName.empty()) { - promotionPattern.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - promotionPattern.add(funcOp.getContext(), - filter, options); - } - (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); - } - - LinalgPromotionOptions options; - LinalgTransformationFilter filter; -}; - /// Configurable pass to apply pattern-based linalg peeling. struct LinalgStrategyPeelPass : public LinalgStrategyPeelPassBase { @@ -508,14 +476,6 @@ return std::make_unique(opName, opt, filter); } -/// Create a LinalgStrategyPromotePass. -std::unique_ptr> -mlir::createLinalgStrategyPromotePass( - StringRef opName, const LinalgPromotionOptions &opt, - const LinalgTransformationFilter &filter) { - return std::make_unique(opName, opt, filter); -} - /// Create a LinalgStrategyGeneralizePass. 
std::unique_ptr> mlir::createLinalgStrategyGeneralizePass( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -394,36 +394,3 @@ return failure(); return res; } - -namespace { -struct LinalgPromotionPass : public LinalgPromotionBase { - LinalgPromotionPass() = default; - LinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { - this->dynamicBuffers = dynamicBuffers; - this->useAlloca = useAlloca; - } - - void runOnOperation() override { - getOperation().walk([&](LinalgOp op) { - auto options = LinalgPromotionOptions() - .setDynamicBuffers(dynamicBuffers) - .setUseAlloca(useAlloca); - if (failed(promoteSubviewsPrecondition(op, options))) - return; - LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n"); - ImplicitLocOpBuilder b(op.getLoc(), op); - // TODO: signalPassFailure() ? - (void)promoteSubViews(b, op, options); - }); - } -}; -} // namespace - -// TODO: support more transformation options in the pass. 
-std::unique_ptr> -mlir::createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { - return std::make_unique(dynamicBuffers, useAlloca); -} -std::unique_ptr> mlir::createLinalgPromotionPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -691,40 +691,6 @@ return genericOp; } -mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgPromotionOptions options, PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(f)), options(std::move(options)) {} - -mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(f)), - options(std::move(options)) {} - -LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - if (failed(promoteSubviewsPrecondition(op, options))) - return failure(); - - // TODO: We cannot use root update here. This pattern is creating other ops, - // so if the promotion fails, those need to be cleaned up, which doesnt seem - // to be happening here. So to fail properly, we should be cloning the op and - // deleting the previous op. This needs more investigation. 
- rewriter.startRootUpdate(op); - Optional promotedOp = promoteSubViews(rewriter, op, options); - if (!promotedOp) { - rewriter.cancelRootUpdate(op); - return op->emitError("subview promotion failed"); - } - rewriter.finalizeRootUpdate(op); - filter.replaceLinalgTransformationFilter(rewriter, op); - return success(); -} - mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern( MLIRContext *context, LinalgTransformationFilter f, LinalgPeelOptions options, PatternBenefit benefit) diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -1,10 +1,9 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-promotion-options -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s func.func @gemm(%a : memref, %b : memref, %c : memref) { - linalg.matmul {__internal_linalg_transform__ = "START"} - ins(%a, %b: memref, memref) - outs(%c: memref) + linalg.matmul ins(%a, %b: memref, memref) + outs(%c: memref) return } @@ -12,23 +11,39 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C42:.+]] = arith.constant 4.200000e+01 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: scf.for // CHECK: scf.for // CHECK: scf.for // CHECK: %[[T7:.+]] = memref.subview %[[ARG0]] // CHECK: %[[T12:.+]] = memref.subview %[[ARG1]] // CHECK: %[[T17:.+]] = memref.subview %[[ARG2]] -// CHECK: %[[T18:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: %[[T19:.+]] = memref.subview %[[T18]] -// CHECK: %[[T20:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: %[[T21:.+]] = memref.subview %[[T20]] -// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T19]] +// CHECK: %[[A0:.*]] = memref.alloc() : memref<1024xi8> +// 
CHECK: %[[V0:.*]] = memref.view %[[A0]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32> +// CHECK: %[[T19:.+]] = memref.subview %[[V0]] +// CHECK: %[[A1:.*]] = memref.alloc() : memref<1024xi8> +// CHECK: %[[V1:.*]] = memref.view %[[A1]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32> +// CHECK: %[[T21:.+]] = memref.subview %[[V1]] // CHECK: memref.copy %[[T7]], %[[T19]] -// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T21]] // CHECK: memref.copy %[[T17]], %[[T21]] // CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]] -// CHECK-NOT: linalg.fill // CHECK: memref.copy %[[T21]], %[[T17]] -// CHECK: memref.dealloc %[[T18]] -// CHECK: memref.dealloc %[[T20]] +// CHECK: memref.dealloc %[[A0]] +// CHECK: memref.dealloc %[[A1]] + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] + %2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false] } + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -74,9 +74,6 @@ "Test a fused pass that applies patterns from matmul to vectors via " "2-d tiling"), llvm::cl::init(false)}; - Option testPromotionOptions{*this, "test-linalg-promotion-options", - llvm::cl::desc("Test promotion options"), - llvm::cl::init(false)}; Option testTileAndDistributionOptions{ *this, "test-tile-and-distribute-options", llvm::cl::desc("Test tile and distribute options"), @@ -257,27 +254,27 @@ 
//===--------------------------------------------------------------------===// // Linalg subview operands promotion. //===--------------------------------------------------------------------===// - patterns.add>( - ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), - StringAttr::get(ctx, "_views_promoted_"))); - patterns.add>( - ctx, - LinalgPromotionOptions() - .setOperandsToPromote({0}) - .setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter( - StringAttr::get(ctx, "_promote_first_view_"), - StringAttr::get(ctx, "_first_view_promoted_"))); - patterns.add>( - ctx, - LinalgPromotionOptions() - .setOperandsToPromote({1}) - .setUseFullTileBuffers({false, true}) - .setAlignment(32), - LinalgTransformationFilter( - StringAttr::get(ctx, "_promote_views_aligned_"), - StringAttr::get(ctx, "_views_aligned_promoted_"))); + // patterns.add>( + // ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), + // LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), + // StringAttr::get(ctx, "_views_promoted_"))); + // patterns.add>( + // ctx, + // LinalgPromotionOptions() + // .setOperandsToPromote({0}) + // .setUseFullTileBuffersByDefault(true), + // LinalgTransformationFilter( + // StringAttr::get(ctx, "_promote_first_view_"), + // StringAttr::get(ctx, "_first_view_promoted_"))); + // patterns.add>( + // ctx, + // LinalgPromotionOptions() + // .setOperandsToPromote({1}) + // .setUseFullTileBuffers({false, true}) + // .setAlignment(32), + // LinalgTransformationFilter( + // StringAttr::get(ctx, "_promote_views_aligned_"), + // StringAttr::get(ctx, "_views_aligned_promoted_"))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); @@ -300,12 +297,12 @@ LinalgTransformationFilter(StringAttr::get(ctx, startMarker), StringAttr::get(ctx, "L1")))); - patternsVector.emplace_back( - ctx, - std::make_unique>( - ctx, 
LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(ctx, "L1"), - StringAttr::get(ctx, "VEC")))); + // patternsVector.emplace_back( + // ctx, + // std::make_unique>( + // ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), + // LinalgTransformationFilter(StringAttr::get(ctx, "L1"), + // StringAttr::get(ctx, "VEC")))); patternsVector.emplace_back( ctx, std::make_unique( @@ -316,68 +313,6 @@ patternsVector.back().add(ctx); } -//===----------------------------------------------------------------------===// -// Test promotion callbacks -//===----------------------------------------------------------------------===// - -// Allocation call back -static Optional allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, - ArrayRef boundingSubViewSize, - DataLayout &layout) { - SmallVector shape(boundingSubViewSize.size(), -1); - return b - .create( - subView.getLoc(), - MemRefType::get(shape, subView.getType().getElementType(), - /*affineMapComposition =*/{}, 3), - boundingSubViewSize) - .getResult(); -} - -// Deallocation callback -static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { - b.create(buffer.getLoc(), buffer); - return success(); -} - -// Copy in call back -static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, - bool isOutput) { - auto floatType = src.getType().cast().getElementType(); - if (!floatType.isa()) - return failure(); - if (!isOutput) { - Value cst = b.create(src.getLoc(), - FloatAttr::get(floatType, 42.0)); - b.create(src.getLoc(), cst, dst); - } - b.create(src.getLoc(), src, dst); - return success(); -} - -static void fillPromotionCallBackPatterns(MLIRContext *ctx, - RewritePatternSet &patterns) { - patterns.add( - MatmulOp::getOperationName(), ctx, - LinalgTilingOptions().setTileSizes({16, 16, 16}), - LinalgTransformationFilter(StringAttr::get(ctx, "START"), - StringAttr::get(ctx, "PROMOTE"))); - patterns.add>( - ctx, - LinalgPromotionOptions() 
- .setOperandsToPromote({0, 2}) - .setUseFullTileBuffers({false, false}) - .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) - .setCopyInOutFns( - [](OpBuilder &b, Value src, Value dst) -> LogicalResult { - return copyCallBackFn(b, src, dst, false); - }, - [](OpBuilder &b, Value src, Value dst) -> LogicalResult { - return copyCallBackFn(b, src, dst, true); - }), - LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); -} - template static SmallVector getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges) { @@ -657,12 +592,6 @@ }; std::unique_ptr cleanupGuard{(void *)1, lambda}; - if (testPromotionOptions) { - RewritePatternSet patterns(&getContext()); - fillPromotionCallBackPatterns(&getContext(), patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } if (testTileAndDistributionOptions) { RewritePatternSet patterns(&getContext()); fillTileAndDistributePatterns(&getContext(), patterns);