diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -18,7 +18,9 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "math-to-spirv-pattern" @@ -46,6 +48,48 @@ return nullptr; } +/// Check if the type is supported by math-to-spirv conversion. We expect to +/// only see scalars and vectors at this point, with higher-level types already +/// lowered. +static bool isSupportedSourceType(Type originalType) { + if (originalType.isIntOrIndexOrFloat()) + return true; + + if (auto vecTy = originalType.dyn_cast()) { + if (!vecTy.getElementType().isIntOrIndexOrFloat()) + return false; + if (vecTy.isScalable()) + return false; + if (vecTy.getRank() > 1) + return false; + + return true; + } + + return false; +} + +/// Check if all `sourceOp` types are supported by math-to-spirv conversion. +/// Notify of a match failure othwerise and return a `failure` result. +/// This is intended to simplify type checks in `OpConversionPattern`s. +static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, + Operation *sourceOp) { + auto allTypes = llvm::to_vector(sourceOp->getOperandTypes()); + llvm::append_range(allTypes, sourceOp->getResultTypes()); + + for (Type ty : allTypes) { + if (!isSupportedSourceType(ty)) { + return rewriter.notifyMatchFailure( + sourceOp, + llvm::formatv( + "unsupported source type for Math to SPIR-V conversion: {0}", + ty)); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -55,14 +99,36 @@ // normal RewritePattern. namespace { +/// Converts elementwise unary, binary, and ternary standard operations to +/// SPIR-V operations. Checks that source `Op` types are supported. +template +struct CheckedElementwiseOpPattern final + : public spirv::ElementwiseOpPattern { + using BasePattern = typename spirv::ElementwiseOpPattern; + using BasePattern::BasePattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) + return res; + + return BasePattern::matchAndRewrite(op, adaptor, rewriter); + } +}; + /// Converts math.copysign to SPIR-V ops. -class CopySignPattern final : public OpConversionPattern { +struct CopySignPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = getTypeConverter()->convertType(copySignOp.getType()); + if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp); + failed(res)) + return res; + + Type type = getTypeConverter()->convertType(copySignOp.getType()); if (!type) return failure(); @@ -121,14 +187,17 @@ /// SPIR-V does not have a direct operations for counting leading zeros. If /// Shader capability is supported, we can leverage GL FindUMsb to calculate /// it. -class CountLeadingZerosPattern final +struct CountLeadingZerosPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = getTypeConverter()->convertType(countOp.getType()); + if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res)) + return res; + + Type type = getTypeConverter()->convertType(countOp.getType()); if (!type) return failure(); @@ -177,9 +246,16 @@ matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); + if (LogicalResult res = checkSourceOpTypes(rewriter, operation); + failed(res)) + return res; + Location loc = operation.getLoc(); - auto type = this->getTypeConverter()->convertType(operation.getType()); - auto exp = rewriter.create(loc, type, adaptor.getOperand()); + Type type = this->getTypeConverter()->convertType(operation.getType()); + if (!type) + return failure(); + + Value exp = rewriter.create(loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); @@ -198,10 +274,17 @@ matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); + if (LogicalResult res = checkSourceOpTypes(rewriter, operation); + failed(res)) + return res; + Location loc = operation.getLoc(); - auto type = this->getTypeConverter()->convertType(operation.getType()); + Type type = this->getTypeConverter()->convertType(operation.getType()); + if (!type) + return failure(); + auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); - auto onePlus = + Value onePlus = rewriter.create(loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); @@ -215,7 +298,10 @@ LogicalResult matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = getTypeConverter()->convertType(powfOp.getType()); + if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res)) + return res; + + Type dstType = getTypeConverter()->convertType(powfOp.getType()); if (!dstType) return failure(); @@ -241,10 +327,13 @@ LogicalResult matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res)) + return res; + Location loc = roundOp.getLoc(); - auto operand = roundOp.getOperand(); - auto ty = operand.getType(); - auto ety = getElementTypeOrSelf(ty); + Value operand = roundOp.getOperand(); + Type ty = operand.getType(); + Type ety = getElementTypeOrSelf(ty); auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); @@ -287,38 +376,38 @@ patterns .add, ExpM1OpPattern, PowFOpPattern, RoundOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern>( + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, ExpM1OpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern>( + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -19,8 +19,7 @@ /// Converts elementwise unary, binary and ternary standard operations to SPIR-V /// operations. template -class ElementwiseOpPattern final : public OpConversionPattern { -public: +struct ElementwiseOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir @@ -41,3 +41,27 @@ // CHECK: %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16> // CHECK: %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16> // CHECK: return %[[RESULT]] + +// ----- + +// 2-D vectors are not supported. +func.func @copy_sign_2d_vector(%value: vector<3x3xf32>, %sign: vector<3x3xf32>) -> vector<3x3xf32> { + %0 = math.copysign %value, %sign : vector<3x3xf32> + return %0: vector<3x3xf32> +} + +// CHECK-LABEL: func @copy_sign_2d_vector +// CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : vector<3x3xf32> +// CHECK-NEXT: return + +// ----- + +// Tensors are not supported. +func.func @copy_sign_tensor(%value: tensor<3x3xf32>, %sign: tensor<3x3xf32>) -> tensor<3x3xf32> { + %0 = math.copysign %value, %sign : tensor<3x3xf32> + return %0: tensor<3x3xf32> +} + +// CHECK-LABEL: func @copy_sign_tensor +// CHECK-NEXT: math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32> +// CHECK-NEXT: return diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -211,3 +211,51 @@ } } // end module + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// 2-D vectors are not supported. + +// CHECK-LABEL: @vector_2d +func.func @vector_2d(%arg0: vector<2x2xf32>) { + // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32> + %0 = math.cos %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32> + %1 = math.exp %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32> + %2 = math.absf %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32> + %3 = math.ceil %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32> + %4 = math.floor %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32> + %5 = math.powf %arg0, %arg0 : vector<2x2xf32> + // CHECK-NEXT: return + return +} + +// Tensors are not supported. + +// CHECK-LABEL: @tensor_1d +func.func @tensor_1d(%arg0: tensor<2xf32>) { + // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32> + %0 = math.cos %arg0 : tensor<2xf32> + // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32> + %1 = math.exp %arg0 : tensor<2xf32> + // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32> + %2 = math.absf %arg0 : tensor<2xf32> + // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32> + %3 = math.ceil %arg0 : tensor<2xf32> + // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32> + %4 = math.floor %arg0 : tensor<2xf32> + // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32> + %5 = math.powf %arg0, %arg0 : tensor<2xf32> + // CHECK-NEXT: return + return +} + +} // end module diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -100,3 +100,51 @@ } } // end module + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// 2-D vectors are not supported. + +// CHECK-LABEL: @vector_2d +func.func @vector_2d(%arg0: vector<2x2xf32>) { + // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32> + %0 = math.cos %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32> + %1 = math.exp %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32> + %2 = math.absf %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32> + %3 = math.ceil %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32> + %4 = math.floor %arg0 : vector<2x2xf32> + // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32> + %5 = math.powf %arg0, %arg0 : vector<2x2xf32> + // CHECK-NEXT: return + return +} + +// Tensors are not supported. + +// CHECK-LABEL: @tensor_1d +func.func @tensor_1d(%arg0: tensor<2xf32>) { + // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32> + %0 = math.cos %arg0 : tensor<2xf32> + // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32> + %1 = math.exp %arg0 : tensor<2xf32> + // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32> + %2 = math.absf %arg0 : tensor<2xf32> + // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32> + %3 = math.ceil %arg0 : tensor<2xf32> + // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32> + %4 = math.floor %arg0 : tensor<2xf32> + // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32> + %5 = math.powf %arg0, %arg0 : tensor<2xf32> + // CHECK-NEXT: return + return +} + +} // end module