diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h @@ -23,8 +23,11 @@ namespace mlir { namespace spirv { -void populateSPIRVGLSLCanonicalizationPatterns( - mlir::RewritePatternSet &results); +/// Populates patterns to run canonicalization that involves GLSL ops. +/// +/// These patterns cannot be run in default canonicalization because GLSL ops +/// aren't always available. So they should be invoked specifically when needed. +void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h @@ -24,6 +24,11 @@ // Passes //===----------------------------------------------------------------------===// +/// Creates a pass to run canonicalization patterns that involve GLSL ops. +/// These patterns cannot be run in default canonicalization because GLSL ops +/// aren't always available. So they should be invoked specifically when needed. +std::unique_ptr<OperationPass<>> createCanonicalizeGLSLPass(); + /// Creates a module pass that converts composite types used by objects in the /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage /// classes with layout information.
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -17,6 +17,11 @@ let constructor = "mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass()"; } +def SPIRVCanonicalizeGLSL : Pass<"spirv-canonicalize-glsl", ""> { + let summary = "Run canonicalization involving GLSL ops"; + let constructor = "mlir::spirv::createCanonicalizeGLSLPass()"; +} + def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> { let summary = "Decorate SPIR-V composite type with layout info"; let constructor = "mlir::spirv::createLowerABIAttributesPass()"; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ set(LLVM_OPTIONAL_SOURCES + CanonicalizeGLSLPass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp @@ -19,6 +20,7 @@ ) add_mlir_dialect_library(MLIRSPIRVTransforms + CanonicalizeGLSLPass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLSLPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLSLPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLSLPass.cpp @@ -0,0 +1,34 @@ +//===- CanonicalizeGLSLPass.cpp - GLSL Related Canonicalization Pass ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +class CanonicalizeGLSLPass final + : public SPIRVCanonicalizeGLSLBase<CanonicalizeGLSLPass> { +public: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr<OperationPass<>> spirv::createCanonicalizeGLSLPass() { + return std::make_unique<CanonicalizeGLSLPass>(); +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir rename from mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir rename to mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -spirv-canonicalize-glsl %s | FileCheck %s // CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32) func @clamp_fordlessthan(%input: f32) -> f32 { diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt @@ -2,7 +2,6 @@ add_mlir_library(MLIRSPIRVTestPasses TestAvailability.cpp TestEntryPointAbi.cpp - TestGLSLCanonicalization.cpp TestModuleCombiner.cpp EXCLUDE_FROM_LIBMLIR diff --git
a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp deleted file mode 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp +++ /dev/null @@ -1,42 +0,0 @@ -//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific pattterns ===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -class TestGLSLCanonicalizationPass - : public PassWrapper> { -public: - TestGLSLCanonicalizationPass() = default; - TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {} - void runOnOperation() override; - StringRef getArgument() const final { - return "test-spirv-glsl-canonicalization"; - } - StringRef getDescription() const final { - return "Tests SPIR-V canonicalization patterns for GLSL extension."; - } -}; -} // namespace - -void TestGLSLCanonicalizationPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); -} - -namespace mlir { -void registerTestSpirvGLSLCanonicalizationPass() { - PassRegistration(); -} -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -49,7 +49,6 @@ void registerTestPrintNestingPass(); void registerTestReducer(); void registerTestSpirvEntryPointABIPass(); -void registerTestSpirvGLSLCanonicalizationPass(); void registerTestSpirvModuleCombinerPass(); void
registerTestTraitsPass(); void registerTosaTestQuantUtilAPIPass(); @@ -137,7 +136,6 @@ registerTestPrintNestingPass(); registerTestReducer(); registerTestSpirvEntryPointABIPass(); - registerTestSpirvGLSLCanonicalizationPass(); registerTestSpirvModuleCombinerPass(); registerTestTraitsPass(); registerVectorizerTestPass();