diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "math-to-spirv-pattern"
@@ -30,14 +31,85 @@ // normal RewritePattern.
 
 namespace {
 
+/// Converts math.copysign to SPIR-V ops.
+///
+/// SPIR-V has no copysign instruction, so the op is lowered to integer bit
+/// manipulation: both operands are bitcast to integers of the same width, the
+/// magnitude bits of the lhs are combined with the sign bit of the rhs, and
+/// the result is bitcast back to the (converted) float type.
+struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type type = getTypeConverter()->convertType(copySignOp.getType());
+    if (!type)
+      return failure();
+
+    // Take the bit width from the *converted* type: the adaptor operands
+    // already have that type, and it can differ from the source type (for
+    // example when the type converter emulates an unsupported float width or
+    // turns a one-element vector into a scalar).
+    auto vectorType = type.dyn_cast<VectorType>();
+    Type elementType = vectorType ? vectorType.getElementType() : type;
+    auto floatType = elementType.dyn_cast<FloatType>();
+    if (!floatType)
+      return failure();
+
+    Location loc = copySignOp.getLoc();
+    unsigned bitwidth = floatType.getWidth();
+    Type intType = rewriter.getIntegerType(bitwidth);
+
+    // Build the masks with APInt so they are exact for every width; the
+    // 32-bit expression `1u << (bitwidth - 1)` is undefined for f64.
+    Value signMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType,
+        rewriter.getIntegerAttr(intType, APInt::getSignMask(bitwidth)));
+    Value valueMask = rewriter.create<spirv::ConstantOp>(
+        loc, intType,
+        rewriter.getIntegerAttr(intType, APInt::getSignedMaxValue(bitwidth)));
+
+    if (vectorType) {
+      assert(vectorType.getRank() == 1);
+      int count = vectorType.getNumElements();
+      intType = VectorType::get(count, intType);
+
+      // Splat the scalar masks so they match the vector operand type.
+      SmallVector<Value> signSplat(count, signMask);
+      signMask =
+          rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+
+      SmallVector<Value> valueSplat(count, valueMask);
+      valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
+                                                               valueSplat);
+    }
+
+    Value lhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+    Value rhsCast =
+        rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+
+    // result = (lhs & valueMask) | (rhs & signMask)
+    Value value = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{lhsCast, valueMask});
+    Value sign = rewriter.create<spirv::BitwiseAndOp>(
+        loc, intType, ValueRange{rhsCast, signMask});
+
+    Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
+                                                       ValueRange{value, sign});
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
+    return success();
+  }
+};
+
 /// Converts math.expm1 to SPIR-V ops.
 ///
 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
 /// these operations.
 template <typename ExpOp>
-class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
-public:
-  using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
+struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
@@ -57,9 +129,8 @@
 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
 /// these operations.
 template
-class Log1pOpPattern final : public OpConversionPattern {
-public:
-  using OpConversionPattern::OpConversionPattern;
+struct Log1pOpPattern final : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
@@ -83,6 +143,8 @@
 namespace mlir {
 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
+  // Core patterns: lowered to core SPIR-V ops (no GLSL extended instructions).
+  patterns.add(typeConverter, patterns.getContext());
 
   // GLSL patterns
   patterns
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+func @copy_sign_scalar(%value: f32, %sign: f32) -> f32 {
+  %0 = math.copysign %value, %sign : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: func @copy_sign_scalar
+// CHECK-SAME: (%[[VALUE:.+]]: f32, %[[SIGN:.+]]: f32)
+// CHECK: %[[SMASK:.+]] = spv.Constant -2147483648 : i32
+// CHECK: %[[VMASK:.+]] = spv.Constant 2147483647 : i32
+// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : f32 to i32
+// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : f32 to i32
+// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i32
+// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i32
+// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : i32
+// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } {
+
+func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vector<3xf16> {
+  %0 = math.copysign %value, %sign : vector<3xf16>
+  return %0: vector<3xf16>
+}
+
+}
+
+// CHECK-LABEL: func @copy_sign_vector
+// CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
+// CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
+// CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
+// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
+// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
+// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
+// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
+// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
+// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SVMASK]] : vector<3xi16>
+// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
+// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
+// CHECK: return %[[RESULT]]