diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -9,8 +9,8 @@ #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "../SPIRVCommon/Pattern.h" -#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" @@ -21,6 +21,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Debug.h" #include +#include namespace mlir { #define GEN_PASS_DEF_CONVERTARITHTOSPIRV @@ -40,7 +41,7 @@ /// Converts composite arith.constant operation to spirv.Constant. struct ConstantCompositeOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, @@ -50,7 +51,7 @@ /// Converts scalar arith.constant operation to spirv.Constant. struct ConstantScalarOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, @@ -62,7 +63,7 @@ /// This cannot be merged into the template unary/binary pattern due to Vulkan /// restrictions over spirv.SRem and spirv.SMod. struct RemSIOpGLPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, @@ -71,7 +72,7 @@ /// Converts arith.remsi to OpenCL SPIR-V ops. struct RemSIOpCLPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, @@ -93,7 +94,7 @@ /// Converts arith.xori to SPIR-V operations. struct XOrIOpLogicalPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, @@ -103,7 +104,7 @@ /// Converts arith.xori to SPIR-V operations if the type of source is i1 or /// vector of i1. struct XOrIOpBooleanPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, @@ -113,7 +114,7 @@ /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector /// of i1. struct UIToFPI1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, @@ -123,7 +124,7 @@ /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtSII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, @@ -133,7 +134,7 @@ /// Converts arith.extui to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtUII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, @@ -143,7 +144,7 @@ /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector /// of i1. struct TruncII1Pattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, @@ -163,7 +164,7 @@ /// Converts integer compare operation on i1 type operands to SPIR-V ops. class CmpIOpBooleanPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, @@ -173,7 +174,7 @@ /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, @@ -183,7 +184,7 @@ /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, @@ -194,7 +195,7 @@ /// Kernel capability. class CmpFOpNanKernelPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, @@ -216,7 +217,7 @@ class AddICarryOpPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; @@ -225,7 +226,7 @@ /// Converts arith.select to spirv.Select. class SelectOpPattern final : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; @@ -254,7 +255,7 @@ return boolAttr; if (auto intAttr = srcAttr.dyn_cast()) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); - return BoolAttr(); + return {}; } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. @@ -281,7 +282,7 @@ LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' illegal: cannot fit into target type '" << dstType << "'\n"); - return IntegerAttr(); + return {}; } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. @@ -346,7 +347,7 @@ // arith.constant should only have vector or tenor types. assert((srcType.isa())); - auto dstType = getTypeConverter()->convertType(srcType); + Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); @@ -473,7 +474,7 @@ // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. auto srcAttr = cstAttr.cast(); - auto dstAttr = + IntegerAttr dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) return failure(); @@ -577,7 +578,7 @@ if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); - auto dstType = getTypeConverter()->convertType(op.getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(op, dstType, @@ -598,7 +599,7 @@ if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); - auto dstType = getTypeConverter()->convertType(op.getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(op, dstType, @@ -613,16 +614,15 @@ LogicalResult UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcType = adaptor.getOperands().front().getType(); + Type srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } @@ -670,16 +670,15 @@ LogicalResult ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcType = adaptor.getOperands().front().getType(); + Type srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } @@ -691,8 +690,7 @@ LogicalResult TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getResult().getType()); if (!isBoolScalarOrVector(dstType)) return failure(); @@ -719,8 +717,8 @@ Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { assert(adaptor.getOperands().size() == 1); - auto srcType = adaptor.getOperands().front().getType(); - auto dstType = + Type srcType = adaptor.getOperands().front().getType(); + Type dstType = this->getTypeConverter()->convertType(op.getResult().getType()); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) return failure(); @@ -769,9 +767,9 @@ Type type = rewriter.getI32Type(); if (auto vectorType = dstType.dyn_cast()) type = VectorType::get(vectorType.getShape(), type); - auto extLhs = + Value extLhs = rewriter.create(op.getLoc(), type, adaptor.getLhs()); - auto extRhs = + Value extRhs = rewriter.create(op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, @@ -968,7 +966,7 @@ Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *converter = this->template getTypeConverter(); - auto dstType = converter->convertType(op.getType()); + Type dstType = converter->convertType(op.getType()); if (!dstType) return failure(); @@ -1075,8 +1073,9 @@ : public impl::ConvertArithToSPIRVBase { void runOnOperation() override { Operation *op = getOperation(); - auto targetAttr = spirv::lookupTargetEnvOrDefault(op); - auto target = SPIRVConversionTarget::get(targetAttr); + spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;