diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -32,228 +32,6 @@ using namespace mlir; -//===----------------------------------------------------------------------===// -// Operation Conversion -//===----------------------------------------------------------------------===// - -namespace { - -/// Converts composite arith.constant operation to spirv.Constant. -struct ConstantCompositeOpPattern final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts scalar arith.constant operation to spirv.Constant. -struct ConstantScalarOpPattern final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.remsi to GLSL SPIR-V ops. -/// -/// This cannot be merged into the template unary/binary pattern due to Vulkan -/// restrictions over spirv.SRem and spirv.SMod. -struct RemSIOpGLPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.remsi to OpenCL SPIR-V ops. -struct RemSIOpCLPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts bitwise operations to SPIR-V operations. 
This is a special pattern -/// other than the BinaryOpPatternPattern because if the operands are boolean -/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For -/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. -template -struct BitwiseOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.xori to SPIR-V operations. -struct XOrIOpLogicalPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.xori to SPIR-V operations if the type of source is i1 or -/// vector of i1. -struct XOrIOpBooleanPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector -/// of i1. -struct UIToFPI1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector -/// of i1. -struct ExtSII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.extui to spirv.Select if the type of source is i1 or vector -/// of i1. 
-struct ExtUII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector -/// of i1. -struct TruncII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts type-casting standard operations to SPIR-V operations. -template -struct TypeCastingOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation on i1 type operands to SPIR-V ops. -class CmpIOpBooleanPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation to SPIR-V ops. -class CmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating-point comparison operations to SPIR-V ops. -class CmpFOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern requires -/// Kernel capability. 
-class CmpFOpNanKernelPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern does not -/// require additional capability. -class CmpFOpNanNonePattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.addui_extended to spirv.IAddCarry. -class AddUIExtendedOpPattern final - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.mul*i_extended to spirv.*MulExtended. -template -class MulIExtendedOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.select to spirv.Select. -class SelectOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts arith.maxf to spirv.GL.FMax or spirv.CL.fmax. 
-template -class MinMaxFOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -} // namespace - //===----------------------------------------------------------------------===// // Conversion Helpers //===----------------------------------------------------------------------===// @@ -362,157 +140,169 @@ return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); } +namespace { + //===----------------------------------------------------------------------===// -// ConstantOp with composite type +// ConstantOp //===----------------------------------------------------------------------===// -LogicalResult ConstantCompositeOpPattern::matchAndRewrite( - arith::ConstantOp constOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto srcType = dyn_cast(constOp.getType()); - if (!srcType || srcType.getNumElements() == 1) - return failure(); - - // arith.constant should only have vector or tenor types. - assert((isa(srcType))); - - Type dstType = getTypeConverter()->convertType(srcType); - if (!dstType) - return failure(); +/// Converts composite arith.constant operation to spirv.Constant. +struct ConstantCompositeOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - auto dstElementsAttr = dyn_cast(constOp.getValue()); - if (!dstElementsAttr) - return failure(); + LogicalResult + matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = dyn_cast(constOp.getType()); + if (!srcType || srcType.getNumElements() == 1) + return failure(); - ShapedType dstAttrType = dstElementsAttr.getType(); + // arith.constant should only have vector or tensor types. 
+ assert((isa(srcType))); - // If the composite type has more than one dimensions, perform linearization. - if (srcType.getRank() > 1) { - if (isa(srcType)) { - dstAttrType = RankedTensorType::get(srcType.getNumElements(), - srcType.getElementType()); - dstElementsAttr = dstElementsAttr.reshape(dstAttrType); - } else { - // TODO: add support for large vectors. + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) return failure(); - } - } - Type srcElemType = srcType.getElementType(); - Type dstElemType; - // Tensor types are converted to SPIR-V array types; vector types are - // converted to SPIR-V vector/array types. - if (auto arrayType = dyn_cast(dstType)) - dstElemType = arrayType.getElementType(); - else - dstElemType = cast(dstType).getElementType(); - - // If the source and destination element types are different, perform - // attribute conversion. - if (srcElemType != dstElemType) { - SmallVector elements; - if (isa(srcElemType)) { - for (FloatAttr srcAttr : dstElementsAttr.getValues()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast(dstElemType), rewriter); - if (!dstAttr) - return failure(); - elements.push_back(dstAttr); - } - } else if (srcElemType.isInteger(1)) { + auto dstElementsAttr = dyn_cast(constOp.getValue()); + if (!dstElementsAttr) return failure(); - } else { - for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { - IntegerAttr dstAttr = convertIntegerAttr( - srcAttr, cast(dstElemType), rewriter); - if (!dstAttr) - return failure(); - elements.push_back(dstAttr); + + ShapedType dstAttrType = dstElementsAttr.getType(); + + // If the composite type has more than one dimension, perform + // linearization. + if (srcType.getRank() > 1) { + if (isa(srcType)) { + dstAttrType = RankedTensorType::get(srcType.getNumElements(), + srcType.getElementType()); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType); + } else { + // TODO: add support for large vectors. 
+ return failure(); } } - // Unfortunately, we cannot use dialect-specific types for element - // attributes; element attributes only works with builtin types. So we need - // to prepare another converted builtin types for the destination elements - // attribute. - if (isa(dstAttrType)) - dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); + Type srcElemType = srcType.getElementType(); + Type dstElemType; + // Tensor types are converted to SPIR-V array types; vector types are + // converted to SPIR-V vector/array types. + if (auto arrayType = dyn_cast(dstType)) + dstElemType = arrayType.getElementType(); else - dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + dstElemType = cast(dstType).getElementType(); + + // If the source and destination element types are different, perform + // attribute conversion. + if (srcElemType != dstElemType) { + SmallVector elements; + if (isa(srcElemType)) { + for (FloatAttr srcAttr : dstElementsAttr.getValues()) { + FloatAttr dstAttr = + convertFloatAttr(srcAttr, cast(dstElemType), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } else if (srcElemType.isInteger(1)) { + return failure(); + } else { + for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { + IntegerAttr dstAttr = convertIntegerAttr( + srcAttr, cast(dstElemType), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } + + // Unfortunately, we cannot use dialect-specific types for element + // attributes; element attributes only work with builtin types. So we + // need to prepare another converted builtin type for the destination + // elements attribute. 
+ if (isa(dstAttrType)) + dstAttrType = + RankedTensorType::get(dstAttrType.getShape(), dstElemType); + else + dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + + dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); + } - dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); + rewriter.replaceOpWithNewOp(constOp, dstType, + dstElementsAttr); + return success(); } +}; - rewriter.replaceOpWithNewOp(constOp, dstType, - dstElementsAttr); - return success(); -} +/// Converts scalar arith.constant operation to spirv.Constant. +struct ConstantScalarOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -//===----------------------------------------------------------------------===// -// ConstantOp with scalar type -//===----------------------------------------------------------------------===// + LogicalResult + matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = constOp.getType(); + if (auto shapedType = dyn_cast(srcType)) { + if (shapedType.getNumElements() != 1) + return failure(); + srcType = shapedType.getElementType(); + } + if (!srcType.isIntOrIndexOrFloat()) + return failure(); -LogicalResult ConstantScalarOpPattern::matchAndRewrite( - arith::ConstantOp constOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type srcType = constOp.getType(); - if (auto shapedType = dyn_cast(srcType)) { - if (shapedType.getNumElements() != 1) + Attribute cstAttr = constOp.getValue(); + if (auto elementsAttr = dyn_cast(cstAttr)) + cstAttr = elementsAttr.getSplatValue(); + + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) return failure(); - srcType = shapedType.getElementType(); - } - if (!srcType.isIntOrIndexOrFloat()) - return failure(); - Attribute cstAttr = constOp.getValue(); - if (auto elementsAttr = dyn_cast(cstAttr)) - cstAttr = elementsAttr.getSplatValue(); + // 
Floating-point types. + if (isa(srcType)) { + auto srcAttr = cast(cstAttr); + auto dstAttr = srcAttr; - Type dstType = getTypeConverter()->convertType(srcType); - if (!dstType) - return failure(); + // Floating-point types not supported in the target environment are all + // converted to float type. + if (srcType != dstType) { + dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); + if (!dstAttr) + return failure(); + } - // Floating-point types. - if (isa(srcType)) { - auto srcAttr = cast(cstAttr); - auto dstAttr = srcAttr; + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } - // Floating-point types not supported in the target environment are all - // converted to float type. - if (srcType != dstType) { - dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); + // Bool type. + if (srcType.isInteger(1)) { + // arith.constant can use 0/1 instead of true/false for i1 values. We need + // to handle that here. + auto dstAttr = convertBoolAttr(cstAttr, rewriter); if (!dstAttr) return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); } - rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); - return success(); - } - - // Bool type. - if (srcType.isInteger(1)) { - // arith.constant can use 0/1 instead of true/false for i1 values. We need - // to handle that here. - auto dstAttr = convertBoolAttr(cstAttr, rewriter); + // IndexType or IntegerType. Index values are converted to 32-bit integer + // values when converting to SPIR-V. + auto srcAttr = cast(cstAttr); + IntegerAttr dstAttr = + convertIntegerAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } - - // IndexType or IntegerType. Index values are converted to 32-bit integer - // values when converting to SPIR-V. 
- auto srcAttr = cast(cstAttr); - IntegerAttr dstAttr = - convertIntegerAttr(srcAttr, cast(dstType), rewriter); - if (!dstAttr) - return failure(); - rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); - return success(); -} +}; //===----------------------------------------------------------------------===// -// RemSIOpGLPattern +// RemSIOp //===----------------------------------------------------------------------===// /// Returns signed remainder for `lhs` and `rhs` and lets the result follow @@ -545,303 +335,363 @@ return builder.create(loc, type, isPositive, abs, absNegate); } -LogicalResult -RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder( - op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); - rewriter.replaceOp(op, result); +/// Converts arith.remsi to GLSL SPIR-V ops. +/// +/// This cannot be merged into the template unary/binary pattern due to Vulkan +/// restrictions over spirv.SRem and spirv.SMod. 
+struct RemSIOpGLPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - return success(); -} + LogicalResult + matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value result = emulateSignedRemainder( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); + rewriter.replaceOp(op, result); -//===----------------------------------------------------------------------===// -// RemSIOpCLPattern -//===----------------------------------------------------------------------===// + return success(); + } +}; -LogicalResult -RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder( - op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); - rewriter.replaceOp(op, result); +/// Converts arith.remsi to OpenCL SPIR-V ops. +struct RemSIOpCLPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - return success(); -} + LogicalResult + matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value result = emulateSignedRemainder( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); + rewriter.replaceOp(op, result); + + return success(); + } +}; //===----------------------------------------------------------------------===// -// BitwiseOpPattern +// BitwiseOp //===----------------------------------------------------------------------===// +/// Converts bitwise operations to SPIR-V operations. This is a special pattern +/// other than the BinaryOpPatternPattern because if the operands are boolean +/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For +/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 
template -LogicalResult -BitwiseOpPattern::matchAndRewrite( - Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - Type dstType = this->getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); - - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { - rewriter.template replaceOpWithNewOp(op, dstType, - adaptor.getOperands()); - } else { - rewriter.template replaceOpWithNewOp(op, dstType, - adaptor.getOperands()); +struct BitwiseOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() == 2); + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands()); + } else { + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands()); + } + return success(); } - return success(); -} +}; //===----------------------------------------------------------------------===// -// XOrIOpLogicalPattern +// XOrIOp //===----------------------------------------------------------------------===// -LogicalResult XOrIOpLogicalPattern::matchAndRewrite( - arith::XOrIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); +/// Converts arith.xori to SPIR-V operations. 
+struct XOrIOpLogicalPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() == 2); - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); - rewriter.replaceOpWithNewOp(op, dstType, - adaptor.getOperands()); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); - return success(); -} + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); -//===----------------------------------------------------------------------===// -// XOrIOpBooleanPattern -//===----------------------------------------------------------------------===// + return success(); + } +}; + +/// Converts arith.xori to SPIR-V operations if the type of source is i1 or +/// vector of i1. 
+struct XOrIOpBooleanPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -LogicalResult XOrIOpBooleanPattern::matchAndRewrite( - arith::XOrIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() == 2); - if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); - rewriter.replaceOpWithNewOp(op, dstType, - adaptor.getOperands()); - return success(); -} + rewriter.replaceOpWithNewOp( + op, dstType, adaptor.getOperands()); + return success(); + } +}; //===----------------------------------------------------------------------===// -// UIToFPI1Pattern +// UIToFPOp //===----------------------------------------------------------------------===// -LogicalResult -UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); +/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector +/// of i1. 
+struct UIToFPI1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); + LogicalResult + matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); -} + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); + } +}; //===----------------------------------------------------------------------===// -// ExtSII1Pattern +// ExtSIOp //===----------------------------------------------------------------------===// -LogicalResult -ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value operand = adaptor.getIn(); - if (!isBoolScalarOrVector(operand.getType())) - return failure(); +/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector +/// of i1. 
+struct ExtSII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - Location loc = op.getLoc(); - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); - - Value allOnes; - if (auto intTy = dyn_cast(dstType)) { - unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create( - loc, intTy, - rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); - } else if (auto vectorTy = dyn_cast(dstType)) { - unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create( - loc, vectorTy, - SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); - } else { - return rewriter.notifyMatchFailure( - loc, llvm::formatv("unhandled type: {0}", dstType)); - } + LogicalResult + matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value operand = adaptor.getIn(); + if (!isBoolScalarOrVector(operand.getType())) + return failure(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp(op, dstType, operand, allOnes, - zero); - return success(); -} + Location loc = op.getLoc(); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + Value allOnes; + if (auto intTy = dyn_cast(dstType)) { + unsigned componentBitwidth = intTy.getWidth(); + allOnes = rewriter.create( + loc, intTy, + rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); + } else if (auto vectorTy = dyn_cast(dstType)) { + unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); + allOnes = rewriter.create( + loc, vectorTy, + SplatElementsAttr::get(vectorTy, + APInt::getAllOnes(componentBitwidth))); + } else { + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unhandled type: {0}", dstType)); + } + + Value zero = 
spirv::ConstantOp::getZero(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, operand, allOnes, + zero); + return success(); + } +}; //===----------------------------------------------------------------------===// -// ExtUII1Pattern +// ExtUIOp //===----------------------------------------------------------------------===// -LogicalResult -ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); +/// Converts arith.extui to spirv.Select if the type of source is i1 or vector +/// of i1. +struct ExtUII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); -} + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); + } +}; //===----------------------------------------------------------------------===// -// TruncII1Pattern +// TruncIOp 
//===----------------------------------------------------------------------===// -LogicalResult -TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type dstType = getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); +/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector +/// of i1. +struct TruncII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - if (!isBoolScalarOrVector(dstType)) - return failure(); + LogicalResult + matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); - Location loc = op.getLoc(); - auto srcType = adaptor.getOperands().front().getType(); - // Check if (x & 1) == 1. - Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); - - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); - return success(); -} + if (!isBoolScalarOrVector(dstType)) + return failure(); + + Location loc = op.getLoc(); + auto srcType = adaptor.getOperands().front().getType(); + // Check if (x & 1) == 1. 
+ Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); + Value maskedSrc = rewriter.create( + loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = rewriter.create(loc, maskedSrc, mask); + + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); + return success(); + } +}; //===----------------------------------------------------------------------===// -// TypeCastingOpPattern +// TypeCastingOp //===----------------------------------------------------------------------===// +/// Converts type-casting standard operations to SPIR-V operations. template -LogicalResult TypeCastingOpPattern::matchAndRewrite( - Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 1); - Type srcType = adaptor.getOperands().front().getType(); - Type dstType = this->getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); - - if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) - return failure(); +struct TypeCastingOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - if (dstType == srcType) { - // Due to type conversion, we are seeing the same source and target type. - // Then we can just erase this operation by forwarding its operand. 
- rewriter.replaceOp(op, adaptor.getOperands().front()); - } else { - rewriter.template replaceOpWithNewOp(op, dstType, - adaptor.getOperands()); + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() == 1); + Type srcType = adaptor.getOperands().front().getType(); + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + return failure(); + + if (dstType == srcType) { + // Due to type conversion, we are seeing the same source and target type. + // Then we can just erase this operation by forwarding its operand. + rewriter.replaceOp(op, adaptor.getOperands().front()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); } - return success(); -} +}; //===----------------------------------------------------------------------===// -// CmpIOpBooleanPattern +// CmpIOp //===----------------------------------------------------------------------===// -LogicalResult CmpIOpBooleanPattern::matchAndRewrite( - arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type srcType = op.getLhs().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - Type dstType = getTypeConverter()->convertType(srcType); - if (!dstType) - return getTypeConversionFailure(rewriter, op, srcType); +/// Converts integer compare operation on i1 type operands to SPIR-V ops. 
+class CmpIOpBooleanPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: { - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - case arith::CmpIPredicate::ne: { - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: { - // There are no direct corresponding instructions in SPIR-V for such cases. - // Extend them to 32-bit and do comparision then. - Type type = rewriter.getI32Type(); - if (auto vectorType = dyn_cast(dstType)) - type = VectorType::get(vectorType.getShape(), type); - Value extLhs = - rewriter.create(op.getLoc(), type, adaptor.getLhs()); - Value extRhs = - rewriter.create(op.getLoc(), type, adaptor.getRhs()); - - rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, - extRhs); - return success(); - } - default: - break; + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = op.getLhs().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return getTypeConversionFailure(rewriter, op, srcType); + + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + case arith::CmpIPredicate::ne: { + rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()); + return success(); + } + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: { + // There are no direct corresponding instructions in SPIR-V for such + // cases. 
Extend them to 32-bit and do comparision then. + Type type = rewriter.getI32Type(); + if (auto vectorType = dyn_cast(dstType)) + type = VectorType::get(vectorType.getShape(), type); + Value extLhs = + rewriter.create(op.getLoc(), type, adaptor.getLhs()); + Value extRhs = + rewriter.create(op.getLoc(), type, adaptor.getRhs()); + + rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, + extRhs); + return success(); + } + default: + break; + } + return failure(); } - return failure(); -} +}; -//===----------------------------------------------------------------------===// -// CmpIOpPattern -//===----------------------------------------------------------------------===// +/// Converts integer compare operation to SPIR-V ops. +class CmpIOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; -LogicalResult -CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type srcType = op.getLhs().getType(); - if (isBoolScalarOrVector(srcType)) - return failure(); - Type dstType = getTypeConverter()->convertType(srcType); - if (!dstType) - return getTypeConversionFailure(rewriter, op, srcType); + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = op.getLhs().getType(); + if (isBoolScalarOrVector(srcType)) + return failure(); + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return getTypeConversionFailure(rewriter, op, srcType); - switch (op.getPredicate()) { + switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (spirvOp::template hasTrait() && \ @@ -854,216 +704,253 @@ adaptor.getRhs()); \ return success(); - DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); - DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); - DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); - 
DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); - DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); - DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); - DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); - DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); - DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); - DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); + DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH + } + return failure(); } - return failure(); -} +}; //===----------------------------------------------------------------------===// // CmpFOpPattern //===----------------------------------------------------------------------===// -LogicalResult -CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - switch (op.getPredicate()) { +/// Converts floating-point comparison operations to SPIR-V ops. +class CmpFOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ adaptor.getRhs()); \ return success(); - // Ordered. 
- DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); - DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); - DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); - DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); - DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); - DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - // Unordered. - DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); - DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); - DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); - DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); - DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); - DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + // Ordered. + DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered. 
+ DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); #undef DISPATCH - default: - break; + default: + break; + } + return failure(); } - return failure(); -} - -//===----------------------------------------------------------------------===// -// CmpFOpNanKernelPattern -//===----------------------------------------------------------------------===// +}; -LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (op.getPredicate() == arith::CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } +/// Converts floating point NaN check to SPIR-V ops. This pattern requires +/// Kernel capability. 
+class CmpFOpNanKernelPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; - if (op.getPredicate() == arith::CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getPredicate() == arith::CmpFPredicate::ORD) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); - return success(); - } - - return failure(); -} + return success(); + } -//===----------------------------------------------------------------------===// -// CmpFOpNanNonePattern -//===----------------------------------------------------------------------===// + if (op.getPredicate() == arith::CmpFPredicate::UNO) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } -LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (op.getPredicate() != arith::CmpFPredicate::ORD && - op.getPredicate() != arith::CmpFPredicate::UNO) return failure(); + } +}; - Location loc = op.getLoc(); - auto *converter = getTypeConverter(); +/// Converts floating point NaN check to SPIR-V ops. This pattern does not +/// require additional capability. +class CmpFOpNanNonePattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; - Value replace; - if (converter->getOptions().enableFastMathMode) { - if (op.getPredicate() == arith::CmpFPredicate::ORD) { - // Ordered comparsion checks if neither operand is NaN. 
- replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getPredicate() != arith::CmpFPredicate::ORD && + op.getPredicate() != arith::CmpFPredicate::UNO) + return failure(); + + Location loc = op.getLoc(); + auto *converter = getTypeConverter(); + + Value replace; + if (converter->getOptions().enableFastMathMode) { + if (op.getPredicate() == arith::CmpFPredicate::ORD) { + // Ordered comparsion checks if neither operand is NaN. + replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); + } else { + // Unordered comparsion checks if either operand is NaN. + replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); + } } else { - // Unordered comparsion checks if either operand is NaN. - replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); + Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + + replace = rewriter.create(loc, lhsIsNan, rhsIsNan); + if (op.getPredicate() == arith::CmpFPredicate::ORD) + replace = rewriter.create(loc, replace); } - } else { - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); - replace = rewriter.create(loc, lhsIsNan, rhsIsNan); - if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); + rewriter.replaceOp(op, replace); + return success(); } - - rewriter.replaceOp(op, replace); - return success(); -} +}; //===----------------------------------------------------------------------===// -// AddUIExtendedOpPattern +// AddUIExtendedOp //===----------------------------------------------------------------------===// -LogicalResult AddUIExtendedOpPattern::matchAndRewrite( - arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type dstElemTy = 
adaptor.getLhs().getType(); - Location loc = op->getLoc(); - Value result = rewriter.create(loc, adaptor.getLhs(), - adaptor.getRhs()); - - Value sumResult = rewriter.create( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create( - loc, result, llvm::ArrayRef(1)); - - // Convert the carry value to boolean. - Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create(loc, carryValue, one); - - rewriter.replaceOp(op, {sumResult, carryResult}); - return success(); -} +/// Converts arith.addui_extended to spirv.IAddCarry. +class AddUIExtendedOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstElemTy = adaptor.getLhs().getType(); + Location loc = op->getLoc(); + Value result = rewriter.create(loc, adaptor.getLhs(), + adaptor.getRhs()); + + Value sumResult = rewriter.create( + loc, result, llvm::ArrayRef(0)); + Value carryValue = rewriter.create( + loc, result, llvm::ArrayRef(1)); + + // Convert the carry value to boolean. + Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); + Value carryResult = rewriter.create(loc, carryValue, one); + + rewriter.replaceOp(op, {sumResult, carryResult}); + return success(); + } +}; //===----------------------------------------------------------------------===// -// MulIExtendedOpPattern +// MulIExtendedOp //===----------------------------------------------------------------------===// +/// Converts arith.mul*i_extended to spirv.*MulExtended. 
template -LogicalResult MulIExtendedOpPattern::matchAndRewrite( - ArithMulOp op, typename ArithMulOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); - Value result = - rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); - - Value low = rewriter.create(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create(loc, result, - llvm::ArrayRef(1)); - - rewriter.replaceOp(op, {low, high}); - return success(); -} +class MulIExtendedOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value result = + rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); + + Value low = rewriter.create(loc, result, + llvm::ArrayRef(0)); + Value high = rewriter.create(loc, result, + llvm::ArrayRef(1)); + + rewriter.replaceOp(op, {low, high}); + return success(); + } +}; //===----------------------------------------------------------------------===// -// SelectOpPattern +// SelectOp //===----------------------------------------------------------------------===// -LogicalResult -SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - return success(); -} +/// Converts arith.select to spirv.Select. 
+class SelectOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + return success(); + } +}; //===----------------------------------------------------------------------===// -// MaxFOpPattern +// MaxFOp //===----------------------------------------------------------------------===// +/// Converts arith.maxf to spirv.GL.FMax or spirv.CL.fmax. template -LogicalResult MinMaxFOpPattern::matchAndRewrite( - Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto *converter = this->template getTypeConverter(); - Type dstType = converter->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); - - // arith.maxf/minf: - // "if one of the arguments is NaN, then the result is also NaN." - // spirv.GL.FMax/FMin - // "which operand is the result is undefined if one of the operands - // is a NaN." - // spirv.CL.fmax/fmin: - // "If one argument is a NaN, Fmin returns the other argument." - - Location loc = op.getLoc(); - Value spirvOp = rewriter.create(loc, dstType, adaptor.getOperands()); - - if (converter->getOptions().enableFastMathMode) { - rewriter.replaceOp(op, spirvOp); - return success(); - } +class MinMaxFOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = this->template getTypeConverter(); + Type dstType = converter->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + // arith.maxf/minf: + // "if one of the arguments is NaN, then the result is also NaN." 
+ // spirv.GL.FMax/FMin + // "which operand is the result is undefined if one of the operands + // is a NaN." + // spirv.CL.fmax/fmin: + // "If one argument is a NaN, Fmin returns the other argument." + + Location loc = op.getLoc(); + Value spirvOp = + rewriter.create(loc, dstType, adaptor.getOperands()); + + if (converter->getOptions().enableFastMathMode) { + rewriter.replaceOp(op, spirvOp); + return success(); + } - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); - Value select1 = rewriter.create(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = rewriter.create(loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = rewriter.create(loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); - rewriter.replaceOp(op, select2); - return success(); -} + rewriter.replaceOp(op, select2); + return success(); + } +}; + +} // namespace //===----------------------------------------------------------------------===// // Pattern Population