Add an optional `canonicalize` flag to populateDecomposeLinalgOpsPattern.

When the flag is set (the default), the `GenericOp` canonicalization patterns
are added next to the decomposition pattern. The test pass exposes the flag as
`canonicalize-decomposed-ops` and the lit test exercises both settings.

NOTE(review): text inside angle brackets appears to have been lost from some
context lines of this patch (`patterns.insert(...)`, `tensor`). Those lines are
left as found; confirm them against the tree before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -48,7 +48,10 @@
 
 /// Populate patterns for splitting a `LinalgOp` with multiple statements within
 /// its payload into multiple `GenericOp` that have a single statement.
-void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+/// When `canonicalize` is true, the `GenericOp` canonicalization patterns are
+/// added to `patterns` as well.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
+                                       bool canonicalize = true);
 
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -376,6 +376,8 @@
 }
 
 void mlir::linalg::populateDecomposeLinalgOpsPattern(
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool canonicalize) {
   patterns.insert(patterns.getContext());
+  if (canonicalize)
+    GenericOp::getCanonicalizationPatterns(patterns, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
--- a/mlir/test/Dialect/Linalg/decompose-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
+// RUN: mlir-opt -pass-pipeline="test-linalg-decompose-ops{canonicalize-decomposed-ops=false}, cse" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="test-linalg-decompose-ops{canonicalize-decomposed-ops=true}, cse" -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
 
 func.func @simple_op(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor)
     -> (tensor, tensor) {
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -33,10 +33,16 @@
     return "Test Linalg decomposition patterns";
   }
 
+  Option<bool> canonicalizeDecomposedOps{
+      *this, "canonicalize-decomposed-ops",
+      llvm::cl::desc("Canonicalize the decomposed ops."),
+      llvm::cl::init(true)};
+
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     RewritePatternSet decompositionPatterns(context);
-    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
+                                              canonicalizeDecomposedOps);
     if (failed(applyPatternsAndFoldGreedily(
             getOperation(), std::move(decompositionPatterns)))) {
       return signalPassFailure();