diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -88,6 +88,13 @@
     linalg::LinalgTransformationFilter filter =
         linalg::LinalgTransformationFilter());
 
+/// Create a LinalgStrategyPadPass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPadPass(
+    StringRef opName = "",
+    linalg::LinalgPaddingOptions opt = linalg::LinalgPaddingOptions(),
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyPromotePass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
     StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -248,6 +248,19 @@
   ];
 }
 
+def LinalgStrategyPadPass
+    : FunctionPass<"linalg-strategy-pad-pass"> {
+  let summary = "Configurable pass to apply padding and hoisting.";
+  let constructor = "mlir::createLinalgStrategyPadPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyPromotePass
     : FunctionPass<"linalg-strategy-promote-pass"> {
   let summary = "Configurable pass to apply pattern-based linalg promotion.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -46,6 +46,22 @@
   linalg::LinalgTilingOptions options;
 };
 
+/// Represent one application of LinalgStrategyPadPass.
+struct Pad : public Transformation {
+  Pad(StringRef name, linalg::LinalgPaddingOptions options,
+      LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), opName(name), options(options) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyPadPass(opName, options, m));
+  }
+
+private:
+  std::string opName;
+  linalg::LinalgPaddingOptions options;
+};
+
 /// Represent one application of createLinalgStrategyPromotePass.
 struct Promote : public Transformation {
   Promote(StringRef name, linalg::LinalgPromotionOptions options,
@@ -147,6 +163,21 @@
       LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? tile(opName, options) : *this;
   }
+  /// Append a pattern to pad and hoist the operands of op `opName` with
+  /// padding `options`.
+  CodegenStrategy &pad(StringRef opName, linalg::LinalgPaddingOptions options,
+                       LinalgTransformationFilter::FilterFunction f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<Pad>(opName, options, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to pad and hoist the operands of op
+  /// `opName` with padding `options`.
+  CodegenStrategy &
+  padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
+        LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? pad(opName, options, f) : *this;
+  }
   /// Append a pattern to add a level of promotion for `LinalgOpType` with
   /// promotion `options`.
   CodegenStrategy &
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -68,6 +68,39 @@
   LinalgTransformationFilter filter;
 };
 
+/// Configurable pass to apply hoisting and padding.
+struct LinalgStrategyPadPass
+    : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
+
+  LinalgStrategyPadPass() = default;
+
+  LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
+                        LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet paddingPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      paddingPattern.add<LinalgPaddingPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
+                                               filter);
+    }
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern))))
+      signalPassFailure();
+  }
+
+  LinalgPaddingOptions options;
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg generalization.
 struct LinalgStrategyGeneralizePass
     : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
@@ -332,6 +365,13 @@
   return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
 }
 
+/// Create a LinalgStrategyPadPass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
+                                  LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
+}
+
 /// Create a LinalgStrategyPromotePass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgStrategyPromotePass(StringRef opName,
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,28 +1,53 @@
-// Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result.
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER
-
-
-// CHECK-LABEL: func @matmul(
-// OUTER-LABEL: func @matmul(
-// GENER-LABEL: func @matmul(
-func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
-  linalg.matmul
-    ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
-    outs(%C: memref<1584x1584xf32>)
-
-  // CHECK: vector.matrix_multiply
-  // CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
-  // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
-
-  // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
-
-  // GENER: linalg.generic
-  // GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"]
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" -split-input-file | FileCheck %s --check-prefix=CHECK-INTRINSIC
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" -split-input-file | FileCheck %s --check-prefix=CHECK-OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
+
+// CHECK-INTRINSIC: func @matmul(
+// CHECK-OUTER: func @matmul(
+func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {
+
+  // Check the matrix intrinsic lowering is triggered.
+  // CHECK-INTRINSIC: vector.matrix_multiply
+  // CHECK-INTRINSIC-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
+  // CHECK-INTRINSIC-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
+
+  // Check the outer product lowering is triggered.
+  // CHECK-OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
+  linalg.matmul ins(%arg0, %arg1: memref<72x72xf32>, memref<72x72xf32>) outs(%arg2: memref<72x72xf32>)
   return
 }
+
+// -----
+
+// CHECK-INTERCHANGE: func @matmul(
+func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+  // CHECK-INTERCHANGE-DAG: %[[C16:.*]] = arith.constant 16
+  // CHECK-INTERCHANGE-DAG: %[[C32:.*]] = arith.constant 32
+  // CHECK-INTERCHANGE-DAG: %[[C64:.*]] = arith.constant 64
+
+  // Check the tile loops are interchanged.
+  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C32]]
+  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C64]]
+  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C16]]
+
+  // Check the operation has been generalized and interchanged.
+  // CHECK-INTERCHANGE: linalg.generic
+  // CHECK-INTERCHANGE-SAME: iterator_types = ["parallel", "reduction", "parallel"]
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
+  return %0 : tensor<72x72xf32>
+}
+
+// -----
+
+// CHECK-PAD: func @matmul(
+func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+
+  // Check the padding of the input operands has been hoisted out of the tile loop nest.
+  // CHECK-PAD-COUNT-2: linalg.pad_tensor %{{.*}} nofold
+  // CHECK-PAD-COUNT-3: scf.for
+  // CHECK-PAD: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
+  return %0 : tensor<72x72xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -54,6 +54,7 @@
 
   void runStrategy(LinalgTilingOptions tilingOptions,
                    LinalgTilingOptions registerTilingOptions,
+                   LinalgPaddingOptions paddingOptions,
                    vector::VectorContractLowering vectorContractLowering,
                    vector::VectorTransferSplit vectorTransferSplit);
 
@@ -86,6 +87,16 @@
       *this, "register-promote-full-tile-pad",
       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
       llvm::cl::init(false)};
+  Option<bool> pad{*this, "pad", llvm::cl::desc("Pad the operands."),
+                   llvm::cl::init(false)};
+  ListOption<int64_t> packPaddings{
+      *this, "pack-paddings",
+      llvm::cl::desc("Operand packing flags when padding."),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> hoistPaddings{
+      *this, "hoist-paddings",
+      llvm::cl::desc("Operand hoisting depths when padding."),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
                           llvm::cl::init(false)};
@@ -132,9 +143,18 @@
       llvm::cl::init("")};
 };
 
+// For now, just assume the neutral value is the zero of the element type.
+// In the future, it should be derived from the element type and the op.
+static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
+  auto t = getElementTypeOrSelf(op.get());
+  return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
+                                     b.getZeroAttr(t));
+}
+
 void TestLinalgCodegenStrategy::runStrategy(
     LinalgTilingOptions tilingOptions,
     LinalgTilingOptions registerTilingOptions,
+    LinalgPaddingOptions paddingOptions,
     vector::VectorContractLowering vectorContractLowering,
     vector::VectorTransferSplit vectorTransferSplit) {
   assert(!anchorOpName.empty());
@@ -150,6 +170,7 @@
           LinalgPromotionOptions()
               .setAlignment(16)
               .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+      .padIf(pad, anchorOpName, paddingOptions)
       .generalizeIf(generalize, anchorOpName)
       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
@@ -191,6 +212,21 @@
   registerTilingOptions =
       registerTilingOptions.setTileSizes(registerTileSizes);
 
+  LinalgPaddingOptions paddingOptions;
+  auto packFunc = [&](OpOperand &opOperand) {
+    return opOperand.getOperandNumber() < packPaddings.size()
+               ? packPaddings[opOperand.getOperandNumber()]
+               : false;
+  };
+  auto hoistingFunc = [&](OpOperand &opOperand) {
+    return opOperand.getOperandNumber() < hoistPaddings.size()
+               ? hoistPaddings[opOperand.getOperandNumber()]
+               : 0;
+  };
+  paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+  paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
+  paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
+
   vector::VectorContractLowering vectorContractLowering =
       llvm::StringSwitch<vector::VectorContractLowering>(
           vectorizeContractionTo.getValue())
@@ -206,8 +242,8 @@
           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
           .Default(vector::VectorTransferSplit::None);
 
-  runStrategy(tilingOptions, registerTilingOptions, vectorContractLowering,
-              vectorTransferSplit);
+  runStrategy(tilingOptions, registerTilingOptions, paddingOptions,
+              vectorContractLowering, vectorTransferSplit);
 }
 
 namespace mlir {
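Usage sketch (illustrative, not part of the patch): the snippet below shows how the new pad stage slots into a CodegenStrategy pipeline. `funcOp` and `paddingOptions` are placeholder names for a FuncOp and a configured LinalgPaddingOptions (built, for example, as in TestLinalgCodegenStrategy.cpp above); the exact driver call may differ from what is shown here.

  // Tile a matmul, then pad and hoist its operands out of the tile loops.
  CodegenStrategy strategy;
  strategy
      .tile("linalg.matmul",
            LinalgTilingOptions().setTileSizes({16, 32, 64}))
      // New hook added by this patch: pad (optionally nofold) and hoist
      // the matmul operands according to `paddingOptions`.
      .padIf(/*b=*/true, "linalg.matmul", paddingOptions);
  // Build the strategy pass pipeline and run it on the anchor function.
  strategy.transform(funcOp);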