diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -40,10 +40,6 @@ linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops); -std::unique_ptr> -createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca); -std::unique_ptr> createLinalgPromotionPass(); - std::unique_ptr> createLinalgInlineScalarOperandsPass(); @@ -101,14 +97,6 @@ const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyPromotePass. -std::unique_ptr> createLinalgStrategyPromotePass( - StringRef opName = "", - const linalg::LinalgPromotionOptions &opt = - linalg::LinalgPromotionOptions(), - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyGeneralizePass. std::unique_ptr> createLinalgStrategyGeneralizePass( StringRef opName = "", const linalg::LinalgTransformationFilter &filter = diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -112,18 +112,6 @@ ]; } -def LinalgPromotion : Pass<"linalg-promote-subviews", "func::FuncOp"> { - let summary = "Promote subview ops to local buffers"; - let constructor = "mlir::createLinalgPromotionPass()"; - let options = [ - Option<"dynamicBuffers", "test-promote-dynamic", "bool", - /*default=*/"false", "Test generation of dynamic promoted buffers">, - Option<"useAlloca", "test-use-alloca", "bool", - /*default=*/"false", "Test generation of alloca'ed buffers."> - ]; - let dependentDialects = ["linalg::LinalgDialect"]; -} - def LinalgTiling : Pass<"linalg-tile", "func::FuncOp"> { let summary = "Tile operations in the linalg dialect"; let constructor = "mlir::createLinalgTilingPass()"; @@ -222,19 +210,6 @@ ]; } -def LinalgStrategyPromotePass - : Pass<"linalg-strategy-promote-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based linalg promotion."; - let constructor = "mlir::createLinalgStrategyPromotePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - def LinalgStrategyGeneralizePass : Pass<"linalg-strategy-generalize-pass", "func::FuncOp"> { let summary = "Configurable pass to apply pattern-based generalization."; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -191,7 +191,6 @@ }]; } - def PadOp : Op { @@ -232,6 +231,45 @@ }]; } +def PromoteOp : Op { + let description = [{ + Promotes the specified operands of the target into a separate memory buffer. + + At this point, this transform does not allow customizing alloc/dealloc + functions nor the behavior on copy in/out operations. + + #### Return modes + + This operation applies to a single Linalg op that satisfies the + `promoteSubviewsPrecondition`, otherwise it fails. + + If the operations referred to by the `target` PDLOperation promote + properly, the transform succeeds. + + When successful, the return handle points to the $target operation that + was modified inplace. + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$operands_to_promote, + DefaultValuedAttr:$use_full_tile_buffers, + UnitAttr:$use_full_tiles_by_default, + UnitAttr:$use_alloca, + OptionalAttr:$alignment); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def ScalarizeOp : Op { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -81,23 +81,6 @@ linalg::LinalgPaddingOptions options; }; -/// Represent one application of createLinalgStrategyPromotePass. -struct Promote : public Transformation { - Promote(StringRef name, linalg::LinalgPromotionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyPromotePass(opName, options, m)); - } - -private: - std::string opName; - linalg::LinalgPromotionOptions options; -}; - /// Represent one application of createLinalgStrategyGeneralizePass. struct Generalize : public Transformation { explicit Generalize(StringRef name, @@ -253,22 +236,6 @@ LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? pad(opName, std::move(options), std::move(f)) : *this; } - /// Append a pattern to add a level of promotion for `LinalgOpType` with - /// promotion `options`. - CodegenStrategy & - promote(StringRef opName, const linalg::LinalgPromotionOptions &options, - const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to add a level of promotion for - /// `LinalgOpType` with promotion `options`. - CodegenStrategy & - promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? promote(opName, std::move(options), std::move(f)) : *this; - } /// Append a pattern to generalize named operations. CodegenStrategy & generalize(StringRef opName, diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -302,12 +302,6 @@ useFullTileBuffersDefault = use; return *this; } - /// Allow the use of dynamically-sized buffers. - bool dynamicBuffers = false; - LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) { - dynamicBuffers = dynamic; - return *this; - } /// Alignment of promoted buffer. If `None` do not specify alignment. Optional alignment = None; LinalgPromotionOptions &setAlignment(unsigned align) { @@ -363,8 +357,6 @@ /// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). /// 2. Take a full view on the buffer. /// 3. Take a partial slice of the full view in step 2. and copy into it. -/// Infers statically sized buffers from subViews unless `dynamicBuffers` is -/// true. /// /// Return the modified linalg op (the modification happens in place) as well /// as all the copy ops created. @@ -992,54 +984,6 @@ LinalgTransformationFilter filter; }; -/// -/// Linalg promotion patterns. -/// -/// Apply the `promoteSubViews` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `promoteSubViews` for more details. -struct LinalgBasePromotionPattern : public RewritePattern { - /// Entry point to match any LinalgOp OpInterface. - /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgPromotionOptions options = LinalgPromotionOptions(), - PatternBenefit benefit = 1); - /// Entry point to match a specific Linalg op. - LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Promotion options. - LinalgPromotionOptions options; -}; - -template -struct LinalgPromotionPattern : public LinalgBasePromotionPattern { - /// SFINAE: This constructor can only trigger for concrete ops that have a - /// static `getOperationName` method. - template - LinalgPromotionPattern( - MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options, - f, benefit) {} - /// This constructor is available to anyone. - LinalgPromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(opName, context, options, f, benefit) {} -}; - /// /// Linalg peeling patterns. /// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -423,6 +423,41 @@ return success(); } +//===----------------------------------------------------------------------===// +// PromoteOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::PromoteOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { + LinalgPromotionOptions promotionOptions; + if (!getOperandsToPromote().empty()) + promotionOptions = promotionOptions.setOperandsToPromote( + extractFromI64ArrayAttr(getOperandsToPromote())); + if (getUseFullTilesByDefault()) + promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( + getUseFullTilesByDefault()); + if (getUseAlloca()) + promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); + if (!getUseFullTileBuffers().empty()) + promotionOptions = promotionOptions.setUseFullTileBuffers( + llvm::to_vector(getUseFullTileBuffers().getAsValueRange())); + if (getAlignment().hasValue()) + promotionOptions = promotionOptions.setAlignment(*getAlignment()); + + if (failed(promoteSubviewsPrecondition(target, promotionOptions))) + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + SimpleRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + FailureOr res = promoteSubViews(rewriter, target, promotionOptions); + if (failed(res)) + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + results.push_back(target); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // ScalarizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -230,38 +230,6 @@ LinalgTransformationFilter filter; }; -/// Configurable pass to apply pattern-based linalg promotion. -struct LinalgStrategyPromotePass - : public LinalgStrategyPromotePassBase { - - LinalgStrategyPromotePass() = default; - - LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt, - LinalgTransformationFilter filt) - : options(std::move(opt)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet promotionPattern(funcOp.getContext()); - if (!anchorOpName.empty()) { - promotionPattern.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - promotionPattern.add(funcOp.getContext(), - filter, options); - } - (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern)); - } - - LinalgPromotionOptions options; - LinalgTransformationFilter filter; -}; - /// Configurable pass to apply pattern-based linalg peeling. struct LinalgStrategyPeelPass : public LinalgStrategyPeelPassBase { @@ -508,14 +476,6 @@ return std::make_unique(opName, opt, filter); } -/// Create a LinalgStrategyPromotePass. -std::unique_ptr> -mlir::createLinalgStrategyPromotePass( - StringRef opName, const LinalgPromotionOptions &opt, - const LinalgTransformationFilter &filter) { - return std::make_unique(opName, opt, filter); -} - /// Create a LinalgStrategyGeneralizePass. std::unique_ptr> mlir::createLinalgStrategyGeneralizePass( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -135,9 +135,6 @@ CopyCallbackFn copyInFn; CopyCallbackFn copyOutFn; - /// Allow the use of dynamically-sized buffers. - bool dynamicBuffers; - /// Alignment of promoted buffer. Optional alignment; }; @@ -145,8 +142,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( LinalgOp linalgOp, const LinalgPromotionOptions &options) - : subViews(), dynamicBuffers(options.dynamicBuffers), - alignment(options.alignment) { + : subViews(), alignment(options.alignment) { assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); auto vUseFullTileBuffers = options.useFullTileBuffers.value_or(llvm::SmallBitVector()); @@ -394,36 +390,3 @@ return failure(); return res; } - -namespace { -struct LinalgPromotionPass : public LinalgPromotionBase { - LinalgPromotionPass() = default; - LinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { - this->dynamicBuffers = dynamicBuffers; - this->useAlloca = useAlloca; - } - - void runOnOperation() override { - getOperation().walk([&](LinalgOp op) { - auto options = LinalgPromotionOptions() - .setDynamicBuffers(dynamicBuffers) - .setUseAlloca(useAlloca); - if (failed(promoteSubviewsPrecondition(op, options))) - return; - LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n"); - ImplicitLocOpBuilder b(op.getLoc(), op); - // TODO: signalPassFailure() ? - (void)promoteSubViews(b, op, options); - }); - } -}; -} // namespace - -// TODO: support more transformation options in the pass. -std::unique_ptr> -mlir::createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { - return std::make_unique(dynamicBuffers, useAlloca); -} -std::unique_ptr> mlir::createLinalgPromotionPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -691,40 +691,6 @@ return genericOp; } -mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgPromotionOptions options, PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(f)), options(std::move(options)) {} - -mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter f, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(f)), - options(std::move(options)) {} - -LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - if (failed(promoteSubviewsPrecondition(op, options))) - return failure(); - - // TODO: We cannot use root update here. This pattern is creating other ops, - // so if the promotion fails, those need to be cleaned up, which doesnt seem - // to be happening here. So to fail properly, we should be cloning the op and - // deleting the previous op. This needs more investigation. - rewriter.startRootUpdate(op); - Optional promotedOp = promoteSubViews(rewriter, op, options); - if (!promotedOp) { - rewriter.cancelRootUpdate(op); - return op->emitError("subview promotion failed"); - } - rewriter.finalizeRootUpdate(op); - filter.replaceLinalgTransformationFilter(rewriter, op); - return success(); -} - mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern( MLIRContext *context, LinalgTransformationFilter f, LinalgPeelOptions options, PatternBenefit benefit) diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -1,6 +1,4 @@ -// RUN: mlir-opt %s -linalg-promote-subviews -split-input-file | FileCheck %s -// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" -split-input-file | FileCheck %s --check-prefix=DYNAMIC -// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" -split-input-file | FileCheck %s --check-prefix=ALLOCA +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s #map1 = affine_map<(d0) -> (d0 + 2)> #map2 = affine_map<(d0) -> (d0 + 4)> @@ -44,25 +42,19 @@ // CHECK: %[[vB:.*]] = memref.subview {{.*}} : memref // CHECK: %[[vC:.*]] = memref.subview {{.*}} : memref /// -// CHECK: %[[tmpA:.*]] = memref.alloc() : memref<32xi8> -// ALLOCA: %[[tmpA:.*]] = memref.alloca() : memref<32xi8> +// CHECK: %[[tmpA:.*]] = memref.alloca() : memref<32xi8> // CHECK: %[[fullA:.*]] = memref.view %[[tmpA]][{{.*}}][{{.*}}] : memref<32xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref to memref /// -// CHECK: %[[tmpB:.*]] = memref.alloc() : memref<48xi8> -// ALLOCA: %[[tmpB:.*]] = memref.alloca() : memref<48xi8> +// CHECK: %[[tmpB:.*]] = memref.alloca() : memref<48xi8> // CHECK: %[[fullB:.*]] = memref.view %[[tmpB]][{{.*}}][{{.*}}] : memref<48xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref to memref /// -// CHECK: %[[tmpC:.*]] = memref.alloc() : memref<24xi8> -// ALLOCA: %[[tmpC:.*]] = memref.alloca() : memref<24xi8> +// CHECK: %[[tmpC:.*]] = memref.alloca() : memref<24xi8> // CHECK: %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref to memref -// CHECK: emref.copy %[[vA]], %[[partialA]] : memref to memref +// CHECK: memref.copy %[[vA]], %[[partialA]] : memref to memref // CHECK: memref.copy %[[vB]], %[[partialB]] : memref to memref // CHECK: memref.copy %[[vC]], %[[partialC]] : memref to memref // @@ -72,12 +64,25 @@ // CHECK: memref to // CHECK: memref // -// CHECK: memref.dealloc %[[tmpA]] : memref<32xi8> -// CHECK: memref.dealloc %[[tmpB]] : memref<48xi8> -// CHECK: memref.dealloc %[[tmpC]] : memref<24xi8> -// ALLOCA-NOT: memref.dealloc %[[tmpA]] : memref<32xi8> -// ALLOCA-NOT: memref.dealloc %[[tmpB]] : memref<48xi8> -// ALLOCA-NOT: memref.dealloc %[[tmpC]] : memref<24xi8> +// CHECK-NOT: memref.dealloc %[[tmpA]] : memref<32xi8> +// CHECK-NOT: memref.dealloc %[[tmpB]] : memref<48xi8> +// CHECK-NOT: memref.dealloc %[[tmpC]] : memref<24xi8> + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 { use_alloca } + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} // ----- @@ -119,17 +124,14 @@ /// // CHECK: %[[tmpA_f64:.*]] = memref.alloc() : memref<64xi8> // CHECK: %[[fullA_f64:.*]] = memref.view %[[tmpA_f64]][{{.*}}][{{.*}}] : memref<64xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref /// // CHECK: %[[tmpB_f64:.*]] = memref.alloc() : memref<96xi8> // CHECK: %[[fullB_f64:.*]] = memref.view %[[tmpB_f64]][{{.*}}][{{.*}}] : memref<96xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref /// // CHECK: %[[tmpC_f64:.*]] = memref.alloc() : memref<48xi8> // CHECK: %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref -// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref // CHECK: %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref // CHECK: memref.copy %[[vA_f64]], %[[partialA_f64]] : memref to memref @@ -146,6 +148,23 @@ // CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8> // CHECK: memref.dealloc %[[tmpC_f64]] : memref<48xi8> +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} + + // ----- #map0 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> @@ -155,29 +174,29 @@ #map7 = affine_map<(d0, d1, d2) -> (d1, d2)> #map8 = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK: promote_rank_reducing_subviews([[arg0:%.+]]: memref<{{.*}}>, [[arg1:%.+]]: memref<{{.*}}>, [[arg2:%.+]]: memref<{{.*}}>, [[lb1:%.+]]: index, [[lb2:%.+]]: index, [[lb3:%.+]]: index, [[lb4:%.+]]: index, [[lb5:%.+]]: index, [[lb6:%.+]]: index, [[ub1:%.+]]: index, [[ub2:%.+]]: index +// CHECK: promote_rank_reducing_subviews(%[[arg0:.+]]: memref<{{.*}}>, %[[arg1:.+]]: memref<{{.*}}>, %[[arg2:.+]]: memref<{{.*}}>, %[[lb1:.+]]: index, %[[lb2:.+]]: index, %[[lb3:.+]]: index, %[[lb4:.+]]: index, %[[lb5:.+]]: index, %[[lb6:.+]]: index, %[[ub1:.+]]: index, %[[ub2:.+]]: index func.func @promote_rank_reducing_subviews(%arg0: memref, %arg1: memref<128x3x3x64xf32, #map0>, %arg2: memref, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %ub1: index, %ub2: index) { %13 = memref.subview %arg0[%arg3, 0, %arg4, %arg8] [1, 1, %ub1, 32] [1, 1, 1, 1] : memref to memref %14 = memref.subview %arg1[0, %arg6, %arg7, %arg8] [128, 1, 1, 32] [1, 1, 1, 1] : memref<128x3x3x64xf32, #map0> to memref<128x32xf32, #map5> %9 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [1, 1, %ub2, 128] [1, 1, 1, 1] : memref to memref - // CHECK: [[a_alloc:%.+]] = memref.alloc - // CHECK: [[a_view:%.+]] = memref.view [[a_alloc]]{{.*}} - // CHECK: [[a_pro_subview:%.+]] = memref.subview [[a_view]][0, 0] [[[ub1]], {{%.+}}] [1, 1] + // CHECK: %[[a_alloc:.+]] = memref.alloc + // CHECK: %[[a_view:.+]] = memref.view %[[a_alloc]]{{.*}} + // CHECK: %[[a_pro_subview:.+]] = memref.subview %[[a_view]][0, 0] [%[[ub1]], {{.+}}] [1, 1] // CHECK: memref.alloc - // CHECK: [[b_view:%.+]] = memref.view - // CHECK: [[b_pro_subview:%.+]] = memref.subview [[b_view]] + // CHECK: %[[b_view:.+]] = memref.view + // CHECK: %[[b_pro_subview:.+]] = memref.subview %[[b_view]] // CHECK: memref.alloc - // CHECK: [[c_view:%.+]] = memref.view - // CHECK: [[c_pro_subview:%.+]] = memref.subview [[c_view]] + // CHECK: %[[c_view:.+]] = memref.view + // CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]] // CHECK-COUNT-3: memref.copy // CHECK: linalg.generic - // CHECK-SAME: ins([[a_pro_subview]], [[b_pro_subview]] - // CHECK-SAME: outs([[c_pro_subview]] + // CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]] + // CHECK-SAME: outs(%[[c_pro_subview]] linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : memref, memref<128x32xf32, #map5>) outs(%9 : memref) { ^bb0(%arg9: f32, %arg10: f32, %arg11: f32): @@ -188,3 +207,19 @@ return } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.generic"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -1,10 +1,9 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-promotion-options -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s func.func @gemm(%a : memref, %b : memref, %c : memref) { - linalg.matmul {__internal_linalg_transform__ = "START"} - ins(%a, %b: memref, memref) - outs(%c: memref) + linalg.matmul ins(%a, %b: memref, memref) + outs(%c: memref) return } @@ -12,23 +11,39 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C42:.+]] = arith.constant 4.200000e+01 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: scf.for // CHECK: scf.for // CHECK: scf.for // CHECK: %[[T7:.+]] = memref.subview %[[ARG0]] // CHECK: %[[T12:.+]] = memref.subview %[[ARG1]] // CHECK: %[[T17:.+]] = memref.subview %[[ARG2]] -// CHECK: %[[T18:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: %[[T19:.+]] = memref.subview %[[T18]] -// CHECK: %[[T20:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: %[[T21:.+]] = memref.subview %[[T20]] -// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T19]] +// CHECK: %[[A0:.*]] = memref.alloc() : memref<1024xi8> +// CHECK: %[[V0:.*]] = memref.view %[[A0]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32> +// CHECK: %[[T19:.+]] = memref.subview %[[V0]] +// CHECK: %[[A1:.*]] = memref.alloc() : memref<1024xi8> +// CHECK: %[[V1:.*]] = memref.view %[[A1]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32> +// CHECK: %[[T21:.+]] = memref.subview %[[V1]] // CHECK: memref.copy %[[T7]], %[[T19]] -// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T21]] // CHECK: memref.copy %[[T17]], %[[T21]] // CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]] -// CHECK-NOT: linalg.fill // CHECK: memref.copy %[[T21]], %[[T17]] -// CHECK: memref.dealloc %[[T18]] -// CHECK: memref.dealloc %[[T20]] +// CHECK: memref.dealloc %[[A0]] +// CHECK: memref.dealloc %[[A1]] + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] + %2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false] } + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s -check-prefix=CHECK-1D -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s -check-prefix=CHECK-2D - -func.func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, - %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, - %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { - linalg.matmul {__internal_linalg_transform__ = "START"} - ins(%A, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, - memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) - outs(%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) - return -} - -// CHECK-1D-LABEL:func @matmul -// CHECK-1D: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32> -// -// CHECK-1D: vector.transfer_read {{.*}} : memref<8x16xf32, #{{.*}}>, vector<8x16xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32> -// CHECK-1D: vector.transfer_read {{.*}} : memref<16x12xf32, #{{.*}}>, vector<16x12xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32> -// CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32> -// -// CHECK-1D: vector.contract -// CHECK-1D-SAME: iterator_types = ["parallel", "parallel", "reduction"] -// CHECK-1D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> -// -// CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32>, vector<8x12xf32> -// CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}> - -// CHECK-2D-LABEL:func @matmul -// CHECK-2D: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32> -// CHECK-2D: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32> -// CHECK-2D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32> -// -// CHECK-2D: memref.copy -// CHECK-2D: memref.copy -// CHECK-2D: memref.copy -// -// CHECK-2D: vector.contract -// CHECK-2D-SAME: iterator_types = ["parallel", "parallel", "reduction"] -// CHECK-2D-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> -// -// CHECK-2D: memref.copy diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,10 +1,8 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s // CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. // CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// Map corresponding to a 2D memory access where the stride along all dims are unknown. -// CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> // CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> // CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)> // CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> @@ -177,162 +175,6 @@ // CHECK: ins({{.*}}: memref, memref) // CHECK: outs({{.*}}: memref) -func.func @promote_subview_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { - %c2000 = arith.constant 2000 : index - %c3000 = arith.constant 3000 : index - %c4000 = arith.constant 4000 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - %2 = memref.dim %arg1, %c1 : memref - scf.for %arg3 = %c0 to %0 step %c2000 { - scf.for %arg4 = %c0 to %2 step %c3000 { - scf.for %arg5 = %c0 to %1 step %c4000 { - %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : - memref to memref - %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : - memref to memref - %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : - memref to memref - linalg.matmul {__internal_linalg_transform__ = "_promote_views_"} - ins(%3, %4: memref, - memref) - outs(%5: memref) - } - } - } - return -} -// CHECK-LABEL: func @promote_subview_matmul -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index -// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index -// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> -// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref -// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] -// CHECK-SAME: memref to memref -// CHECK: %[[a1:.*]] = memref.alloc() : memref<48000000xi8> -// CHECK: %[[v1:.*]] = memref.view %[[a1]]{{.*}} : memref<48000000xi8> to memref -// CHECK: %[[l1:.*]] = memref.subview %[[v1]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] -// CHECK-SAME: memref to memref -// CHECK: %[[a2:.*]] = memref.alloc() : memref<24000000xi8> -// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref -// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] -// CHECK-SAME: memref to memref -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref -// CHECK: memref.copy %[[s1]], %[[l1]] : memref to memref -// CHECK: memref.copy %[[s2]], %[[l2]] : memref to memref -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref, memref) -// CHECK-SAME: outs(%[[v2]] : memref) - -func.func @promote_first_subview_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { - %c2000 = arith.constant 2000 : index - %c3000 = arith.constant 3000 : index - %c4000 = arith.constant 4000 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - %2 = memref.dim %arg1, %c1 : memref - scf.for %arg3 = %c0 to %0 step %c2000 { - scf.for %arg4 = %c0 to %2 step %c3000 { - scf.for %arg5 = %c0 to %1 step %c4000 { - %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : - memref to memref - %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : - memref to memref - %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : - memref to memref - linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"} - ins(%3, %4: memref, - memref) - outs(%5: memref) - } - } - } - return -} -// CHECK-LABEL: func @promote_first_subview_matmul -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index -// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index -// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { -// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> -// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref -// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref -// CHECK-NOT: memref.alloc -// CHECK-NOT: memref.view -// CHECK-NOT: memref.subview -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref -// CHECK-NOT: memref.copy -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref) -// CHECK-SAME: outs(%[[s2]] : memref) - -func.func @aligned_promote_fill(%arg0: memref) { - %c2000 = arith.constant 2000 : index - %c4000 = arith.constant 4000 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cf = arith.constant 1.0 : f32 - %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : - memref to memref - linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"} - ins(%cf : f32) outs(%3 : memref) - return -} -// CHECK-LABEL: func @aligned_promote_fill -// CHECK: %[[cf:.*]] = arith.constant 1.{{.*}} : f32 -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8> -// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref -// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref -// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref) -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref -// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref) - -func.func @aligned_promote_fill_complex(%arg0: memref, offset: ?, strides: [?, 1]>) { - %c2000 = arith.constant 2000 : index - %c4000 = arith.constant 4000 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cf = arith.constant 1.0 : f32 - %cc = complex.create %cf, %cf : complex - %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : - memref, offset: ?, strides: [?, 1]> to memref, offset: ?, strides: [?, ?]> - linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"} - ins(%cc : complex) outs(%3 : memref, offset: ?, strides: [?, ?]>) - return -} -// CHECK-LABEL: func @aligned_promote_fill_complex -// CHECK: %[[cc:.*]] = complex.create {{.*}} : complex -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref, #map{{.*}}> to memref, #map{{.*}}> -// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8> -// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref> -// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref> to memref, #[[$STRIDED_2D_u_1]]> -// CHECK: linalg.fill ins({{.*}} : complex) outs(%[[v0]] : memref>) -// CHECK: memref.copy %[[s0]], %[[l0]] : memref, #map{{.*}}> to memref, #map{{.*}}> -// CHECK: linalg.fill ins(%[[cc]] : complex) outs(%[[v0]] : memref>) - func.func @tile_permute_parallel_loop(%arg0: memref, %arg1: memref, %arg2: memref) { diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir @@ -0,0 +1,230 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s + +// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. +// CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// Map corresponding to a 2D memory access where the stride along all dims are unknown. +// CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +func.func @promote_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = arith.constant 2000 : index + %c3000 = arith.constant 3000 : index + %c4000 = arith.constant 4000 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref + %2 = memref.dim %arg1, %c1 : memref + scf.for %arg3 = %c0 to %0 step %c2000 { + scf.for %arg4 = %c0 to %2 step %c3000 { + scf.for %arg5 = %c0 to %1 step %c4000 { + %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul ins(%3, %4: memref, + memref) + outs(%5: memref) + } + } + } + return +} +// CHECK-LABEL: func @promote_subview_matmul +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index +// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index +// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> +// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref +// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] +// CHECK-SAME: memref to memref +// CHECK: %[[a1:.*]] = memref.alloc() : memref<48000000xi8> +// CHECK: %[[v1:.*]] = memref.view %[[a1]]{{.*}} : memref<48000000xi8> to memref +// CHECK: %[[l1:.*]] = memref.subview %[[v1]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] +// CHECK-SAME: memref to memref +// CHECK: %[[a2:.*]] = memref.alloc() : memref<24000000xi8> +// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref +// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] +// CHECK-SAME: memref to memref +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK: memref.copy %[[s1]], %[[l1]] : memref to memref +// CHECK: memref.copy %[[s2]], %[[l2]] : memref to memref +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref, memref) +// CHECK-SAME: outs(%[[v2]] : memref) + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 { operands_to_promote = [0, 1, 2], use_full_tiles_by_default } + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} + +// ----- + +func.func @promote_first_subview_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c2000 = arith.constant 2000 : index + %c3000 = arith.constant 3000 : index + %c4000 = arith.constant 4000 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref + %2 = memref.dim %arg1, %c1 : memref + scf.for %arg3 = %c0 to %0 step %c2000 { + scf.for %arg4 = %c0 to %2 step %c3000 { + scf.for %arg5 = %c0 to %1 step %c4000 { + %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref to memref + %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref to memref + %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref to memref + linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"} + ins(%3, %4: memref, + memref) + outs(%5: memref) + } + } + } + return +} +// CHECK-LABEL: func @promote_first_subview_matmul +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c2000:.*]] = arith.constant 2000 : index +// CHECK-DAG: %[[c3000:.*]] = arith.constant 3000 : index +// CHECK-DAG: %[[c4000:.*]] = arith.constant 4000 : index +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> +// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref +// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref +// CHECK-NOT: memref.alloc +// CHECK-NOT: memref.view +// CHECK-NOT: memref.subview +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK-NOT: memref.copy +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref) +// CHECK-SAME: outs(%[[s2]] : memref) + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 { operands_to_promote = [0], use_full_tiles_by_default } + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} + +// ----- + +func.func @aligned_promote_fill(%arg0: memref) { + %c2000 = arith.constant 2000 : index + %c4000 = arith.constant 4000 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cf = arith.constant 1.0 : f32 + %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : + memref to memref + linalg.fill + ins(%cf : f32) outs(%3 : memref) + return +} +// CHECK-LABEL: func @aligned_promote_fill +// CHECK: %[[cf:.*]] = arith.constant 1.{{.*}} : f32 +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8> +// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref +// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref +// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref) +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref) + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.fill"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} + +// ----- + +func.func @aligned_promote_fill_complex(%arg0: memref, offset: ?, strides: [?, 1]>) { + %c2000 = arith.constant 2000 : index + %c4000 = arith.constant 4000 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cf = arith.constant 1.0 : f32 + %cc = complex.create %cf, %cf : complex + %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : + memref, offset: ?, strides: [?, 1]> to memref, offset: ?, strides: [?, ?]> + linalg.fill ins(%cc : complex) + outs(%3 : memref, offset: ?, strides: [?, ?]>) + return +} +// CHECK-LABEL: func @aligned_promote_fill_complex +// CHECK: %[[cc:.*]] = complex.create {{.*}} : complex +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref, #map{{.*}}> to memref, #map{{.*}}> +// CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8> +// CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref> +// CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref> to memref, #[[$STRIDED_2D_u_1]]> +// CHECK: linalg.fill ins({{.*}} : complex) outs(%[[v0]] : memref>) +// CHECK: memref.copy %[[s0]], %[[l0]] : memref, #map{{.*}}> to memref, #map{{.*}}> +// CHECK: linalg.fill ins(%[[cc]] : complex) outs(%[[v0]] : memref>) + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} + } + + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = operation "linalg.fill"(%args : !pdl.range) -> (%results : !pdl.range) + rewrite %0 with "transform.dialect" + } +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -62,21 +62,6 @@ Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; - Option testMatmulToVectorPatterns1dTiling{ - *this, "test-matmul-to-vector-patterns-tile-1d", - llvm::cl::desc( - "Test a fused pass that applies patterns from matmul to vectors via " - "1-d tiling"), - llvm::cl::init(false)}; - Option testMatmulToVectorPatterns2dTiling{ - *this, "test-matmul-to-vector-patterns-tile-2d", - llvm::cl::desc( - "Test a fused pass that applies patterns from matmul to vectors via " - "2-d tiling"), - llvm::cl::init(false)}; - Option testPromotionOptions{*this, "test-linalg-promotion-options", - llvm::cl::desc("Test promotion options"), - llvm::cl::init(false)}; Option testTileAndDistributionOptions{ *this, "test-tile-and-distribute-options", llvm::cl::desc("Test tile and distribute options"), @@ -254,31 +239,6 @@ LinalgTransformationFilter(ArrayRef{}, StringAttr::get(ctx, "PERMUTED"))); - //===--------------------------------------------------------------------===// - // Linalg subview operands promotion. - //===--------------------------------------------------------------------===// - patterns.add>( - ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), - StringAttr::get(ctx, "_views_promoted_"))); - patterns.add>( - ctx, - LinalgPromotionOptions() - .setOperandsToPromote({0}) - .setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter( - StringAttr::get(ctx, "_promote_first_view_"), - StringAttr::get(ctx, "_first_view_promoted_"))); - patterns.add>( - ctx, - LinalgPromotionOptions() - .setOperandsToPromote({1}) - .setUseFullTileBuffers({false, true}) - .setAlignment(32), - LinalgTransformationFilter( - StringAttr::get(ctx, "_promote_views_aligned_"), - StringAttr::get(ctx, "_views_aligned_promoted_"))); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. @@ -287,97 +247,6 @@ }); } -static void fillL1TilingAndMatmulToVectorPatterns( - func::FuncOp funcOp, StringRef startMarker, - SmallVectorImpl &patternsVector) { - MLIRContext *ctx = funcOp.getContext(); - patternsVector.emplace_back( - ctx, std::make_unique( - MatmulOp::getOperationName(), ctx, - LinalgTilingOptions() - .setTileSizes({8, 12, 16}) - .setInterchange({1, 0, 2}), - LinalgTransformationFilter(StringAttr::get(ctx, startMarker), - StringAttr::get(ctx, "L1")))); - - patternsVector.emplace_back( - ctx, - std::make_unique>( - ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(ctx, "L1"), - StringAttr::get(ctx, "VEC")))); - - patternsVector.emplace_back( - ctx, std::make_unique( - MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), - LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); - patternsVector.back().add( - ctx, LinalgTransformationFilter().addOpFilter()); - patternsVector.back().add(ctx); -} - -//===----------------------------------------------------------------------===// -// Test promotion callbacks -//===----------------------------------------------------------------------===// - -// Allocation call back -static Optional allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, - ArrayRef boundingSubViewSize, - DataLayout &layout) { - SmallVector shape(boundingSubViewSize.size(), -1); - return b - .create( - subView.getLoc(), - MemRefType::get(shape, subView.getType().getElementType(), - /*affineMapComposition =*/{}, 3), - boundingSubViewSize) - .getResult(); -} - -// Deallocation callback -static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { - b.create(buffer.getLoc(), buffer); - return success(); -} - -// Copy in call back -static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, - bool isOutput) { - auto floatType = src.getType().cast().getElementType(); - if (!floatType.isa()) - return failure(); - if (!isOutput) { - Value cst = b.create(src.getLoc(), - FloatAttr::get(floatType, 42.0)); - b.create(src.getLoc(), cst, dst); - } - b.create(src.getLoc(), src, dst); - return success(); -} - -static void fillPromotionCallBackPatterns(MLIRContext *ctx, - RewritePatternSet &patterns) { - patterns.add( - MatmulOp::getOperationName(), ctx, - LinalgTilingOptions().setTileSizes({16, 16, 16}), - LinalgTransformationFilter(StringAttr::get(ctx, "START"), - StringAttr::get(ctx, "PROMOTE"))); - patterns.add>( - ctx, - LinalgPromotionOptions() - .setOperandsToPromote({0, 2}) - .setUseFullTileBuffers({false, false}) - .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) - .setCopyInOutFns( - [](OpBuilder &b, Value src, Value dst) -> LogicalResult { - return copyCallBackFn(b, src, dst, false); - }, - [](OpBuilder &b, Value src, Value dst) -> LogicalResult { - return copyCallBackFn(b, src, dst, true); - }), - LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); -} - template static SmallVector getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges) { @@ -530,40 +399,6 @@ StringAttr::get(context, "tensors_after_fuse_distribute1"))); } -static void -applyMatmulToVectorPatterns(func::FuncOp funcOp, - bool testMatmulToVectorPatterns1dTiling, - bool testMatmulToVectorPatterns2dTiling) { - MLIRContext *ctx = funcOp.getContext(); - SmallVector stage1Patterns; - if (testMatmulToVectorPatterns1dTiling) { - fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); - } else if (testMatmulToVectorPatterns2dTiling) { - stage1Patterns.emplace_back( - ctx, std::make_unique( - MatmulOp::getOperationName(), ctx, - LinalgTilingOptions() - .setTileSizes({768, 264, 768}) - .setInterchange({1, 2, 0}), - LinalgTransformationFilter(StringAttr::get(ctx, "START"), - StringAttr::get(ctx, "L2")))); - fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); - } - { - // Canonicalization patterns - RewritePatternSet canonicalizationPatterns(funcOp.getContext()); - vector::populateVectorTransferPermutationMapLoweringPatterns( - canonicalizationPatterns); - vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); - stage1Patterns.push_back(std::move(canonicalizationPatterns)); - } - SmallVector frozenStage1Patterns; - llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); - FrozenRewritePatternSet stage2Patterns = - getLinalgTilingCanonicalizationPatterns(ctx); - (void)applyStagedPatterns(funcOp, frozenStage1Patterns, stage2Patterns); -} - static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); @@ -657,12 +492,6 @@ }; std::unique_ptr cleanupGuard{(void *)1, lambda}; - if (testPromotionOptions) { - RewritePatternSet patterns(&getContext()); - fillPromotionCallBackPatterns(&getContext(), patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } if (testTileAndDistributionOptions) { RewritePatternSet patterns(&getContext()); fillTileAndDistributePatterns(&getContext(), patterns); @@ -677,10 +506,6 @@ } if (testPatterns) return applyPatterns(getOperation()); - if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) - return applyMatmulToVectorPatterns(getOperation(), - testMatmulToVectorPatterns1dTiling, - testMatmulToVectorPatterns2dTiling); if (testVectorTransferForwardingPatterns) return applyVectorTransferForwardingPatterns(getOperation()); if (testGenericToVectorPattern)