diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td @@ -40,8 +40,7 @@ (SPV_LogicalEqualOp $lhs, $rhs)>; //===----------------------------------------------------------------------===// -// Re-write spv.Select + spv. to a suitable variant of -// spv. +// spv.Select -> spv.GLSL.*Clamp //===----------------------------------------------------------------------===// def ValuesAreEqual : Constraint>; @@ -53,7 +52,9 @@ [SPV_SLessThanEqualOp, SPV_GLSLSClampOp], [SPV_ULessThanOp, SPV_GLSLUClampOp], [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in { -def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat< + +// Detect: $min < $input, $input < $max +def ConvertComparisonIntoClamp1_#CmpClampPair[0] : Pat< (SPV_SelectOp (CmpClampPair[0] (SPV_SelectOp:$middle0 @@ -67,4 +68,16 @@ $max), (CmpClampPair[1] $input, $min, $max), [(ValuesAreEqual $middle0, $middle1)]>; + +// Detect: $input < $min, $max < $input +def ConvertComparisonIntoClamp2_#CmpClampPair[0] : Pat< + (SPV_SelectOp + (CmpClampPair[0] $max, $input), + $max, + (SPV_SelectOp + (CmpClampPair[0] $input, $min), + $min, + $input + )), + (CmpClampPair[1] $input, $min, $max)>; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp @@ -23,12 +23,18 @@ namespace mlir { namespace spirv { void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results) { - results.add( + results.add( results.getContext()); } } // namespace spirv diff --git a/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/glsl-canonicalize.mlir @@ -1,12 +1,8 @@ // RUN: mlir-opt -split-input-file -spirv-canonicalize-glsl %s | FileCheck %s -// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32) -func @clamp_fordlessthan(%input: f32) -> f32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0.5 : f32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 1.0 : f32 - +// CHECK-LABEL: func @clamp_fordlessthan +// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32) +func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 { // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.FOrdLessThan %min, %input : f32 %mid = spv.Select %0, %input, %min : i1, f32 @@ -19,13 +15,24 @@ // ----- -// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32) -func @clamp_fordlessthanequal(%input: f32) -> f32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0.5 : f32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 1.0 : f32 +// CHECK-LABEL: func @clamp_fordlessthan +// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32) +func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 { + // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.FOrdLessThan %input, %min : f32 + %mid = spv.Select %0, %min, %input : i1, f32 + %1 = spv.FOrdLessThan %max, %input : f32 + %2 = spv.Select %1, %max, %mid : i1, f32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : f32 +} +// ----- + +// CHECK-LABEL: func @clamp_fordlessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32) +func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 { // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.FOrdLessThanEqual %min, %input : f32 %mid = spv.Select %0, %input, %min : i1, f32 @@ -38,13 +45,24 @@ // ----- -// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32) -func @clamp_slessthan(%input: si32) -> si32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0 : si32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 10 : si32 +// CHECK-LABEL: func @clamp_fordlessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32) +func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 { + // CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.FOrdLessThanEqual %input, %min : f32 + %mid = spv.Select %0, %min, %input : i1, f32 + %1 = spv.FOrdLessThanEqual %max, %input : f32 + %2 = spv.Select %1, %max, %mid : i1, f32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : f32 +} + +// ----- +// CHECK-LABEL: func @clamp_slessthan +// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32) +func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 { // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.SLessThan %min, %input : si32 %mid = spv.Select %0, %input, %min : i1, si32 @@ -57,13 +75,24 @@ // ----- -// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32) -func @clamp_slessthanequal(%input: si32) -> si32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0 : si32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 10 : si32 +// CHECK-LABEL: func @clamp_slessthan +// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32) +func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 { + // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.SLessThan %input, %min : si32 + %mid = spv.Select %0, %min, %input : i1, si32 + %1 = spv.SLessThan %max, %input : si32 + %2 = spv.Select %1, %max, %mid : i1, si32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : si32 +} + +// ----- +// CHECK-LABEL: func @clamp_slessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32) +func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 { // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.SLessThanEqual %min, %input : si32 %mid = spv.Select %0, %input, %min : i1, si32 @@ -76,13 +105,24 @@ // ----- -// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32) -func @clamp_ulessthan(%input: i32) -> i32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0 : i32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 10 : i32 +// CHECK-LABEL: func @clamp_slessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32) +func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 { + // CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.SLessThanEqual %input, %min : si32 + %mid = spv.Select %0, %min, %input : i1, si32 + %1 = spv.SLessThanEqual %max, %input : si32 + %2 = spv.Select %1, %max, %mid : i1, si32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : si32 +} +// ----- + +// CHECK-LABEL: func @clamp_ulessthan +// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32) +func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 { // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.ULessThan %min, %input : i32 %mid = spv.Select %0, %input, %min : i1, i32 @@ -95,13 +135,24 @@ // ----- -// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32) -func @clamp_ulessthanequal(%input: i32) -> i32 { - // CHECK: %[[MIN:.*]] = spv.Constant - %min = spv.Constant 0 : i32 - // CHECK: %[[MAX:.*]] = spv.Constant - %max = spv.Constant 10 : i32 +// CHECK-LABEL: func @clamp_ulessthan +// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32) +func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 { + // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.ULessThan %input, %min : i32 + %mid = spv.Select %0, %min, %input : i1, i32 + %1 = spv.ULessThan %max, %input : i32 + %2 = spv.Select %1, %max, %mid : i1, i32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : i32 +} + +// ----- +// CHECK-LABEL: func @clamp_ulessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32) +func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 { // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] %0 = spv.ULessThanEqual %min, %input : i32 %mid = spv.Select %0, %input, %min : i1, i32 @@ -111,3 +162,18 @@ // CHECK-NEXT: spv.ReturnValue [[RES]] spv.ReturnValue %2 : i32 } + +// ----- + +// CHECK-LABEL: func @clamp_ulessthanequal +// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32) +func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 { + // CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]] + %0 = spv.ULessThanEqual %input, %min : i32 + %mid = spv.Select %0, %min, %input : i1, i32 + %1 = spv.ULessThanEqual %max, %input : i32 + %2 = spv.Select %1, %max, %mid : i1, i32 + + // CHECK-NEXT: spv.ReturnValue [[RES]] + spv.ReturnValue %2 : i32 +}