diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -122,7 +122,11 @@ Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate non-32-bit scalar types with 32-bit ones if " - "missing native support"> + "missing native support">, + Option<"enableFastMath", "enable-fast-math", + "bool", /*default=*/"false", + "Enable fast math mode (assuming no NaN and infinity for floating " + "point values) when performing conversion"> ]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -26,6 +26,9 @@ //===----------------------------------------------------------------------===// struct SPIRVConversionOptions { + /// The number of bits used to store a boolean value (eight by default). + unsigned boolNumBits{8}; + /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if /// no native support. /// @@ -45,9 +48,10 @@ /// Use 64-bit integers to convert index types. bool use64bitIndex{false}; - /// The number of bits to store a boolean value. It is eight bits by - /// default. - unsigned boolNumBits{8}; + /// Whether to enable fast math mode during conversion. If true, patterns + /// may assume floating-point inputs are never NaN or infinity, so no guards + /// are emitted to check for and handle such values. Off by default. + bool enableFastMathMode{false}; }; /// Type conversion from builtin types to SPIR-V types for shader interface.
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -219,6 +219,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts arith.maxf/arith.minf to spv.GL.FMax/spv.GL.FMin. +template <typename Op, typename SPIRVOp> +class MinMaxFOpPattern final : public OpConversionPattern<Op> { +public: + using OpConversionPattern<Op>::OpConversionPattern; + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -839,13 +849,25 @@ return failure(); Location loc = op.getLoc(); + auto *converter = getTypeConverter<SPIRVTypeConverter>(); - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value replace; + if (converter->getOptions().enableFastMathMode) { + if (op.getPredicate() == arith::CmpFPredicate::ORD) { + // Ordered: true iff neither operand is NaN; always true under fast math. + replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); + } else { + // Unordered: true iff either operand is NaN; always false under fast math.
+ replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); + } + } else { + Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); - Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); - if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + if (op.getPredicate() == arith::CmpFPredicate::ORD) + replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + } rewriter.replaceOp(op, replace); return success(); @@ -889,6 +911,45 @@ return success(); } +//===----------------------------------------------------------------------===// +// MinMaxFOpPattern +//===----------------------------------------------------------------------===// + +template <typename Op, typename SPIRVOp> +LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); + auto dstType = converter->convertType(op.getType()); + if (!dstType) + return failure(); + + // arith.maxf/minf: + // "if one of the arguments is NaN, then the result is also NaN." + // spv.GL.FMax/FMin: + // "which operand is the result is undefined if one of the operands + // is a NaN."
+ + Location loc = op.getLoc(); + Value minMax = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + + if (converter->getOptions().enableFastMathMode) { + rewriter.replaceOp(op, minMax); + return success(); + } + + Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + + Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, + adaptor.getLhs(), minMax); + Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); + + rewriter.replaceOp(op, select2); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// @@ -932,10 +993,10 @@ CmpFOpNanNonePattern, CmpFOpPattern, AddICarryOpPattern, SelectOpPattern, - spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>, + MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>, + MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>, spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>, spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>, - spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLFMinOp>, spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>, spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp> >(typeConverter, patterns.getContext()); @@ -961,6 +1022,7 @@ SPIRVConversionOptions options; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; + options.enableFastMathMode = this->enableFastMath; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -1093,13 +1093,35 @@ %3 = arith.divf %lhs, %rhs: f32 // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 %4 = arith.remf %lhs, %rhs: f32 - // CHECK: spv.GL.FMax %{{.*}}, %{{.*}}: f32 - %5 = arith.maxf %lhs, %rhs: f32 - // CHECK: spv.GL.FMin %{{.*}}, %{{.*}}: f32 - %6 = arith.minf %lhs, %rhs: f32 return } +// 
CHECK-LABEL: @float32_minf_scalar +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[MIN:.+]] = spv.GL.FMin %[[LHS]], %[[RHS]] : f32 + // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 + // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 + // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MIN]] + // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]] + %0 = arith.minf %arg0, %arg1 : f32 + // CHECK: return %[[SELECT2]] + return %0: f32 +} + +// CHECK-LABEL: @float32_maxf_scalar +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @float32_maxf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[MAX:.+]] = spv.GL.FMax %[[LHS]], %[[RHS]] : f32 + // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 + // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 + // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MAX]] + // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]] + %0 = arith.maxf %arg0, %arg1 : f32 + // CHECK: return %[[SELECT2]] + return %0: f32 +} + // Check int vector types.
// CHECK-LABEL: @int_vector234 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) { diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ArithmeticToSPIRV/fast-math.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt -split-input-file -convert-arith-to-spirv=enable-fast-math -verify-diagnostics %s | FileCheck %s + +module attributes { + spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> +} { + +// CHECK-LABEL: @cmpf_ordered +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 { + // CHECK: %[[T:.+]] = spv.Constant true + %0 = arith.cmpf ord, %arg0, %arg1 : f32 + // CHECK: return %[[T]] + return %0: i1 +} + +// CHECK-LABEL: @cmpf_unordered +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @cmpf_unordered(%arg0 : f32, %arg1 : f32) -> i1 { + // CHECK: %[[F:.+]] = spv.Constant false + %0 = arith.cmpf uno, %arg0, %arg1 : f32 + // CHECK: return %[[F]] + return %0: i1 +} + +} // end module + +// ----- + +module attributes { + spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> +} { + +// CHECK-LABEL: @minf +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[F:.+]] = spv.GL.FMin %[[LHS]], %[[RHS]] + %0 = arith.minf %arg0, %arg1 : f32 + // CHECK: return %[[F]] + return %0: f32 +} + +// CHECK-LABEL: @maxf +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @maxf(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[F:.+]] = spv.GL.FMax %[[LHS]], %[[RHS]] + %0 = arith.maxf %arg0, %arg1 : f32 + // CHECK: return %[[F]] + return %0: f32 +} + +} // end module