diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -865,6 +865,8 @@
   let assemblyFormat = [{
     operands attr-dict `:` type($condition) `,` type($result)
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -236,6 +236,22 @@
                  ConvertLogicalNotOfLogicalNotEqual>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.Select
+//===----------------------------------------------------------------------===//
+
+void spirv::SelectOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // The ConvertComparisonIntoClamp* patterns are generated by the `foreach`
+  // loop in SPIRVCanonicalization.td; keep this list in sync with it.
+  results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
+                 ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_SLessThanOp,
+                 ConvertComparisonIntoClampSPV_SLessThanEqualOp,
+                 ConvertComparisonIntoClampSPV_ULessThanOp,
+                 ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalOr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -38,3 +38,45 @@
 def ConvertLogicalNotOfLogicalNotEqual : Pat<
     (SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
     (SPV_LogicalEqualOp $lhs, $rhs)>;
+
+//===----------------------------------------------------------------------===//
+// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
+// spv.GLSL.<clamp_op>:
+//
+//   %mid = select(cmp(%min, %input), %input, %min)
+//   %res = select(cmp(%mid, %max), %mid, %max)
+//     => spv.GLSL.{F,S,U}Clamp %input, %min, %max
+//
+// NOTE(review): GLSL.std.450 leaves the *Clamp result undefined when
+// min > max (and FClamp defers to FMin/FMax, which do not define which
+// operand is returned for NaN inputs), whereas the select chain is fully
+// defined. Confirm this relaxation is acceptable for these rewrites.
+//===----------------------------------------------------------------------===//
+
+// Holds when the two matched values are the same SSA value.
+def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
+
+// Each pair is (comparison op, clamp op it folds into). One pattern named
+// ConvertComparisonIntoClamp<comparison op> is generated per pair.
+foreach CmpClampPair = [
+    [SPV_FOrdLessThanOp, SPV_GLSLFClampOp],
+    [SPV_FOrdLessThanEqualOp, SPV_GLSLFClampOp],
+    [SPV_SLessThanOp, SPV_GLSLSClampOp],
+    [SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
+    [SPV_ULessThanOp, SPV_GLSLUClampOp],
+    [SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
+def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
+    (SPV_SelectOp
+        (CmpClampPair[0]
+            (SPV_SelectOp:$middle0
+                (CmpClampPair[0] $min, $input),
+                $input,
+                $min
+            ),
+            $max
+        ),
+        $middle1,
+        $max),
+    (CmpClampPair[1] $input, $min, $max),
+    [(ValuesAreEqual $middle0, $middle1)]>;
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -711,3 +711,117 @@
   }
   spv.Return
 }
+
+// -----
+
+// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthan(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThan %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThan %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
+func @clamp_fordlessthanequal(%input: f32) -> f32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0.5 : f32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 1.0 : f32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.FOrdLessThanEqual %min, %input : f32
+  %mid = spv.Select %0, %input, %min : i1, f32
+  %1 = spv.FOrdLessThanEqual %mid, %max : f32
+  %2 = spv.Select %1, %mid, %max : i1, f32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : f32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
+func @clamp_slessthan(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThan %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThan %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
+func @clamp_slessthanequal(%input: si32) -> si32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : si32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : si32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.SLessThanEqual %min, %input : si32
+  %mid = spv.Select %0, %input, %min : i1, si32
+  %1 = spv.SLessThanEqual %mid, %max : si32
+  %2 = spv.Select %1, %mid, %max : i1, si32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : si32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
+func @clamp_ulessthan(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThan %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThan %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
+// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
+func @clamp_ulessthanequal(%input: i32) -> i32 {
+  // CHECK: %[[MIN:.*]] = spv.constant
+  %min = spv.constant 0 : i32
+  // CHECK: %[[MAX:.*]] = spv.constant
+  %max = spv.constant 10 : i32
+
+  // CHECK: %[[RES:.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
+  %0 = spv.ULessThanEqual %min, %input : i32
+  %mid = spv.Select %0, %input, %min : i1, i32
+  %1 = spv.ULessThanEqual %mid, %max : i32
+  %2 = spv.Select %1, %mid, %max : i1, i32
+
+  // CHECK-NEXT: spv.ReturnValue %[[RES]]
+  spv.ReturnValue %2 : i32
+}