diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h @@ -0,0 +1,31 @@ +//===- SPIRVGLSLCanonicalization.h - GLSL-specific patterns -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a function to register SPIR-V GLSL-specific +// canonicalization patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_ +#define MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" + +//===----------------------------------------------------------------------===// +// GLSL canonicalization patterns +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace spirv { +void populateSPIRVGLSLCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context); +} // namespace spirv +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_ diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSPIRV SPIRVAttributes.cpp SPIRVCanonicalization.cpp + SPIRVGLSLCanonicalization.cpp SPIRVDialect.cpp SPIRVEnums.cpp SPIRVOps.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td @@ -38,3 +38,33 @@ def ConvertLogicalNotOfLogicalNotEqual : Pat< (SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)), (SPV_LogicalEqualOp $lhs, $rhs)>; + +//===----------------------------------------------------------------------===// +// Re-write spv.Select + spv. to a suitable variant of +// spv. +//===----------------------------------------------------------------------===// + +def ValuesAreEqual : Constraint>; + +foreach CmpClampPair = [ + [SPV_FOrdLessThanOp, SPV_GLSLFClampOp], + [SPV_FOrdLessThanEqualOp, SPV_GLSLFClampOp], + [SPV_SLessThanOp, SPV_GLSLSClampOp], + [SPV_SLessThanEqualOp, SPV_GLSLSClampOp], + [SPV_ULessThanOp, SPV_GLSLUClampOp], + [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in { +def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat< + (SPV_SelectOp + (CmpClampPair[0] + (SPV_SelectOp:$middle0 + (CmpClampPair[0] $min, $input), + $input, + $min + ), + $max + ), + $middle1, + $max), + (CmpClampPair[1] $input, $min, $max), + [(ValuesAreEqual $middle0, $middle1)]>; +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp @@ -0,0 +1,35 @@ +//===- SPIRVGLSLCanonicalization.cpp - SPIR-V GLSL canonicalization patterns =// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the canonicalization patterns for SPIR-V GLSL-specific ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h" + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +using namespace mlir; + +namespace { +#include "SPIRVCanonicalization.inc" +} // end anonymous namespace + +namespace mlir { +namespace spirv { +void populateSPIRVGLSLCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} +} // namespace spirv +} // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Transforms/glsl_canonicalize.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32) +func @clamp_fordlessthan(%input: f32) -> f32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0.5 : f32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 1.0 : f32 + + // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.FOrdLessThan %min, %input : f32 + %mid = spv.Select %0, %input, %min : i1, f32 + %1 = spv.FOrdLessThan %mid, %max : f32 + %2 = spv.Select %1, %mid, %max : i1, f32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : f32 +} + +// ----- + +// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32) +func @clamp_fordlessthanequal(%input: f32) -> f32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0.5 : f32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 1.0 : f32 + + // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.FOrdLessThanEqual %min, %input : f32 + %mid = spv.Select %0, %input, %min : i1, f32 + %1 = spv.FOrdLessThanEqual %mid, %max : f32 + %2 = spv.Select %1, %mid, %max : i1, f32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : f32 +} + +// ----- + +// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32) +func @clamp_slessthan(%input: si32) -> si32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0 : si32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 10 : si32 + + // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.SLessThan %min, %input : si32 + %mid = spv.Select %0, %input, %min : i1, si32 + %1 = spv.SLessThan %mid, %max : si32 + %2 = spv.Select %1, %mid, %max : i1, si32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : si32 +} + +// ----- + +// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32) +func @clamp_slessthanequal(%input: si32) -> si32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0 : si32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 10 : si32 + + // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.SLessThanEqual %min, %input : si32 + %mid = spv.Select %0, %input, %min : i1, si32 + %1 = spv.SLessThanEqual %mid, %max : si32 + %2 = spv.Select %1, %mid, %max : i1, si32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : si32 +} + +// ----- + +// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32) +func @clamp_ulessthan(%input: i32) -> i32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0 : i32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 10 : i32 + + // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.ULessThan %min, %input : i32 + %mid = spv.Select %0, %input, %min : i1, i32 + %1 = spv.ULessThan %mid, %max : i32 + %2 = spv.Select %1, %mid, %max : i1, i32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : i32 +} + +// ----- + +// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32) +func @clamp_ulessthanequal(%input: i32) -> i32 { + // CHECK: %[[MIN:.*]] = spv.constant + %min = spv.constant 0 : i32 + // CHECK: %[[MAX:.*]] = spv.constant + %max = spv.constant 10 : i32 + + // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.ULessThanEqual %min, %input : i32 + %mid = spv.Select %0, %input, %min : i1, i32 + %1 = spv.ULessThanEqual %mid, %max : i32 + %2 = spv.Select %1, %mid, %max : i1, i32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : i32 +} diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRSPIRVTestPasses TestAvailability.cpp TestEntryPointAbi.cpp + TestGLSLCanonicalization.cpp TestModuleCombiner.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp @@ -0,0 +1,39 @@ +//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific pattterns ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +class TestGLSLCanonicalizationPass + : public PassWrapper> { +public: + TestGLSLCanonicalizationPass() = default; + TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {} + void runOnOperation() override; +}; +} // namespace + +void TestGLSLCanonicalizationPass::runOnOperation() { + OwningRewritePatternList patterns; + spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext()); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +namespace mlir { +void registerTestSpirvGLSLCanonicalizationPass() { + PassRegistration registration( + "test-spirv-glsl-canonicalization", + "Tests SPIR-V canonicalization patterns for GLSL extension."); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -47,6 +47,7 @@ void registerTestPrintNestingPass(); void registerTestReducer(); void registerTestSpirvEntryPointABIPass(); +void registerTestSpirvGLSLCanonicalizationPass(); void registerTestSpirvModuleCombinerPass(); void registerTestTraitsPass(); void registerTosaTestQuantUtilAPIPass(); @@ -115,6 +116,7 @@ registerTestPrintNestingPass(); registerTestReducer(); registerTestSpirvEntryPointABIPass(); + registerTestSpirvGLSLCanonicalizationPass(); registerTestSpirvModuleCombinerPass(); registerTestTraitsPass(); registerVectorizerTestPass();