[mlir][Linalg] Let CodegenStrategy be driven by op type as well as op name

This patch touches three files:

* mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
  - sfinae_enqueue: the assert no longer accepts an empty opName; the
    explicit name must now equal ConcreteOpType::getOperationName().
  - The name-less `explicit` constructors of Tile, Promote and Vectorize
    now initialize opName from LinalgOpType::getOperationName() instead
    of "".
* mlir/test/Dialect/Linalg/codegen-strategy.mlir
  - Adds two RUN lines without `anchor-op=linalg.matmul`. They are
    checked against the same FileCheck prefixes (default and OUTER) as
    the existing anchor-op RUN lines.
* mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
  - Adds a runStrategy member template.
  - Its explicit specialization asserts that anchorOpName is non-empty
    and passes it to every tileIf/promoteIf/vectorizeIf call.
  - The generic definition calls the `.template` member templates
    instead. It presumably supplies the op type as an explicit template
    argument; that argument is not visible in this copy.
  - runOnFunction drops the filterOpName lambda, which was previously
    passed only to the first promoteIf.
  - runOnFunction then dispatches on anchorOpName.empty(). Per the added
    comment, the empty case exercises linalg::MatmulOp.

NOTE(review): this copy has lost its line breaks and the contents of
angle brackets, e.g. `insert>(`, `template <>` and `runStrategy(`
without explicit template arguments. It will not apply with
`git apply`/`patch` as-is; regenerate it from the original commit.

NOTE(review): with the `opName.empty() ||` disjunct removed, any call
site that still reaches sfinae_enqueue with an empty opName now fails the
assert. Confirm that every caller passes a name.

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -38,9 +38,8 @@ void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options, MLIRContext *context, StringRef opName, linalg::LinalgTransformationFilter m) { - assert(opName.empty() || - opName == ConcreteOpType::getOperationName() && - "explicit name must match ConcreteOpType::getOperationName"); + assert(opName == ConcreteOpType::getOperationName() && + "explicit name must match ConcreteOpType::getOperationName"); patterList.insert>(context, options, m); } @@ -61,7 +60,8 @@ struct Tile : public Transformation { explicit Tile(linalg::LinalgTilingOptions options, linalg::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(f), opName(""), options(options) {} + : Transformation(f), opName(LinalgOpType::getOperationName()), + options(options) {} Tile(StringRef name, linalg::LinalgTilingOptions options, linalg::LinalgTransformationFilter::FilterFunction f = nullptr) @@ -88,7 +88,8 @@ explicit Promote( linalg::LinalgPromotionOptions options, linalg::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(f), opName(""), options(options) {} + : Transformation(f), opName(LinalgOpType::getOperationName()), + options(options) {} Promote(StringRef name, linalg::LinalgPromotionOptions options, linalg::LinalgTransformationFilter::FilterFunction f = nullptr) @@ -116,7 +117,8 @@ explicit Vectorize( linalg::LinalgVectorizationOptions options, linalg::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(f), opName(""), options(options) {} + : Transformation(f), opName(LinalgOpType::getOperationName()), + options(options) {} Vectorize(StringRef name, linalg::LinalgVectorizationOptions options, 
linalg::LinalgTransformationFilter::FilterFunction f = nullptr) diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir --- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir +++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir @@ -1,3 +1,6 @@ +// Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result. +// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER diff --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp @@ -47,6 +47,12 @@ void runOnFunction() override; + template + void runStrategy(LinalgTilingOptions tilingOptions, + LinalgTilingOptions registerTilingOptions, + vector::VectorContractLowering vectorContractLowering, + vector::VectorTransferSplit vectorTransferSplit); + ListOption tileSizes{*this, "tile-sizes", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the tile sizes.")}; @@ -107,12 +113,67 @@ }; } // end anonymous namespace 
+template <> +void TestLinalgCodegenStrategy::runStrategy( + LinalgTilingOptions tilingOptions, + LinalgTilingOptions registerTilingOptions, + vector::VectorContractLowering vectorContractLowering, + vector::VectorTransferSplit vectorTransferSplit) { + assert(!anchorOpName.empty()); + CodegenStrategy strategy; + strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions) + .promoteIf(promote, anchorOpName, + LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(promoteFullTile)) + .tileIf(!registerTileSizes.empty(), anchorOpName, + registerTilingOptions) + .promoteIf( + registerPromote, anchorOpName, + LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(registerPromoteFullTile)) + .vectorizeIf(vectorize, anchorOpName) + .setVectorTransformsOptions( + vector::VectorTransformsOptions() + .setVectorTransformsOptions(vectorContractLowering) + .setVectorTransferSplit(vectorTransferSplit)) + .setVectorTransferToSCFOptions( + VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers)); + strategy.transform(getFunction()); +} + +template +void TestLinalgCodegenStrategy::runStrategy( + LinalgTilingOptions tilingOptions, + LinalgTilingOptions registerTilingOptions, + vector::VectorContractLowering vectorContractLowering, + vector::VectorTransferSplit vectorTransferSplit) { + CodegenStrategy strategy; + strategy.tileIf(!tileSizes.empty(), tilingOptions) + .template promoteIf( + promote, LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(promoteFullTile)) + .template tileIf(!registerTileSizes.empty(), + registerTilingOptions) + .template promoteIf( + registerPromote, + LinalgPromotionOptions() + .setAlignment(16) + .setUseFullTileBuffersByDefault(registerPromoteFullTile)) + .template vectorizeIf(vectorize) + .setVectorTransformsOptions( + vector::VectorTransformsOptions() + .setVectorTransformsOptions(vectorContractLowering) + .setVectorTransferSplit(vectorTransferSplit)) + 
.setVectorTransferToSCFOptions( + VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers)); + strategy.transform(getFunction()); +} + /// Apply transformations specified as patterns. void TestLinalgCodegenStrategy::runOnFunction() { - linalg::LinalgTransformationFilter::FilterFunction filterOpName = - [&](Operation *op) -> LogicalResult { - return success(op->getName().getStringRef() == anchorOpName); - }; LinalgTilingOptions tilingOptions; if (!tileSizes.empty()) tilingOptions = tilingOptions.setTileSizes(tileSizes); @@ -137,28 +198,14 @@ .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer) .Default(vector::VectorTransferSplit::None); - CodegenStrategy strategy; - strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions) - .promoteIf(promote, anchorOpName, - LinalgPromotionOptions() - .setAlignment(16) - .setUseFullTileBuffersByDefault(promoteFullTile), - filterOpName) - .tileIf(!registerTileSizes.empty(), anchorOpName, - registerTilingOptions) - .promoteIf( - registerPromote, anchorOpName, - LinalgPromotionOptions() - .setAlignment(16) - .setUseFullTileBuffersByDefault(registerPromoteFullTile)) - .vectorizeIf(vectorize, anchorOpName) - .setVectorTransformsOptions( - vector::VectorTransformsOptions() - .setVectorTransformsOptions(vectorContractLowering) - .setVectorTransferSplit(vectorTransferSplit)) - .setVectorTransferToSCFOptions( - VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers)); - strategy.transform(getFunction()); + // If no anchorOpName is specified, just test that strategy applies properly + // to linalg::MatmulOp. + if (anchorOpName.empty()) + runStrategy(tilingOptions, registerTilingOptions, + vectorContractLowering, vectorTransferSplit); + else + runStrategy(tilingOptions, registerTilingOptions, + vectorContractLowering, vectorTransferSplit); } namespace mlir {