diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -63,19 +63,6 @@ spirv::TargetEnv targetEnv; }; -/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. -template -class SPIRVOpLowering : public OpConversionPattern { -public: - SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), - typeConverter(typeConverter) {} - -protected: - SPIRVTypeConverter &typeConverter; -}; - /// Appends to a pattern list additional patterns for translating the builtin /// `func` op to the SPIR-V dialect. These patterns do not handle shader /// interface/ABI; they convert function parameters to be of SPIR-V allowed diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -341,6 +341,13 @@ /// does not require type conversion. TypeConverter *getTypeConverter() const { return typeConverter; } + template + std::enable_if_t::value, + ConverterTy *> + getTypeConverter() const { + return static_cast(typeConverter); + } + protected: /// See `RewritePattern::RewritePattern` for information on the other /// available constructors. diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -17,6 +17,8 @@ #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -26,9 +28,9 @@ /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation /// builtin variables. template -class LaunchConfigConversion : public SPIRVOpLowering { +class LaunchConfigConversion : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, @@ -38,9 +40,9 @@ /// Pattern lowering subgroup size/id to loading SPIR-V invocation /// builtin variables. template -class SingleDimLaunchConfigConversion : public SPIRVOpLowering { +class SingleDimLaunchConfigConversion : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, @@ -51,9 +53,9 @@ /// a constant with WorkgroupSize decoration. So here we cannot generate a /// builtin variable; instead the information in the `spv.entry_point_abi` /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp. -class WorkGroupSizeConversion : public SPIRVOpLowering { +class WorkGroupSizeConversion : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::BlockDimOp op, ArrayRef operands, @@ -61,9 +63,9 @@ }; /// Pattern to convert a kernel function in GPU dialect within a spv.module. -class GPUFuncOpConversion final : public SPIRVOpLowering { +class GPUFuncOpConversion final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, @@ -74,9 +76,9 @@ }; /// Pattern to convert a gpu.module to a spv.module. -class GPUModuleConversion final : public SPIRVOpLowering { +class GPUModuleConversion final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef operands, @@ -85,9 +87,9 @@ /// Pattern to convert a gpu.return into a SPIR-V return. // TODO: This can go to DRR when GPU return has operands. -class GPUReturnOpConversion final : public SPIRVOpLowering { +class GPUReturnOpConversion final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, @@ -102,17 +104,14 @@ static Optional getLaunchConfigIndex(Operation *op) { auto dimAttr = op->getAttrOfType("dimension"); - if (!dimAttr) { - return {}; - } - if (dimAttr.getValue() == "x") { - return 0; - } else if (dimAttr.getValue() == "y") { - return 1; - } else if (dimAttr.getValue() == "z") { - return 2; - } - return {}; + if (!dimAttr) + return llvm::None; + + return llvm::StringSwitch>(dimAttr.getValue()) + .Case("x", 0) + .Case("y", 1) + .Case("z", 2) + .Default(llvm::None); } template @@ -150,7 +149,8 @@ auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); auto val = workGroupSizeAttr.getValue(index.getValue()); - auto convertedType = typeConverter.convertType(op.getResult().getType()); + auto convertedType = + getTypeConverter()->convertType(op.getResult().getType()); if (!convertedType) return failure(); rewriter.replaceOpWithNewOp( @@ -164,7 +164,7 @@ // Legalizes a GPU function as an entry SPIR-V function. static spirv::FuncOp -lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, +lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef argABIInfo) { @@ -266,7 +266,7 @@ return failure(); } spirv::FuncOp newFuncOp = lowerAsEntryFunction( - funcOp, typeConverter, rewriter, entryPointAttr, argABI); + funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI); if (!newFuncOp) return failure(); newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(), @@ -344,5 +344,5 @@ spirv::BuiltIn::NumSubgroups>, SingleDimLaunchConfigConversion, - WorkGroupSizeConversion>(context, typeConverter); + WorkGroupSizeConversion>(typeConverter, context); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -44,10 +45,9 @@ /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition /// that the linalg.generic op is performing reduction with a workload size that /// can fit in one workgroup. -class SingleWorkgroupReduction final - : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; +struct SingleWorkgroupReduction final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; /// Matches the given linalg.generic op as performing reduction and returns /// the binary op kind if successful. @@ -142,9 +142,11 @@ // TODO: Load to Workgroup storage class first. + auto *typeConverter = getTypeConverter(); + // Get the input element accessed by this invocation. Value inputElementPtr = spirv::getElementPtr( - typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); + *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); Value inputElement = rewriter.create(loc, inputElementPtr); // Perform the group reduction operation. @@ -163,10 +165,10 @@ // Get the output element accessed by this reduction. Value zero = spirv::ConstantOp::getZero( - typeConverter.getIndexType(rewriter.getContext()), loc, rewriter); + typeConverter->getIndexType(rewriter.getContext()), loc, rewriter); SmallVector zeroIndices(originalOutputType.getRank(), zero); Value outputElementPtr = - spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput, + spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, zeroIndices, loc, rewriter); // Write out the final reduction result. This should be only conducted by one @@ -204,5 +206,5 @@ void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert(context, typeConverter); + patterns.insert(typeConverter, context); } diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -16,9 +16,14 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; +//===----------------------------------------------------------------------===// +// Context +//===----------------------------------------------------------------------===// + namespace mlir { struct ScfToSPIRVContextImpl { // Map between the spirv region control flow operation (spv.loop or @@ -37,20 +42,40 @@ ScfToSPIRVContext::ScfToSPIRVContext() { impl = std::make_unique(); } + ScfToSPIRVContext::~ScfToSPIRVContext() = default; +//===----------------------------------------------------------------------===// +// Pattern Declarations +//===----------------------------------------------------------------------===// + namespace { /// Common class for all vector to GPU patterns. template -class SCFToSPIRVPattern : public SPIRVOpLowering { +class SCFToSPIRVPattern : public OpConversionPattern { public: SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter, ScfToSPIRVContextImpl *scfToSPIRVContext) - : SPIRVOpLowering::SPIRVOpLowering(context, converter), - scfToSPIRVContext(scfToSPIRVContext) {} + : OpConversionPattern::OpConversionPattern(context), + scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} protected: ScfToSPIRVContextImpl *scfToSPIRVContext; + // FIXME: We explicitly keep a reference of the type converter here instead of + // passing it to OpConversionPattern during construction. This effectively + // bypasses the conversion framework's automation on type conversion. This is + // needed right now because the conversion framework will unconditionally + // legalize all types used by SCF ops upon discovering them, for example, the + // types of loop carried values. We use SPIR-V variables for those loop + // carried values. Depending on the available capabilities, the SPIR-V + // variable can be different, for example, cooperative matrix or normal + // variable. We'd like to detach the conversion of the loop carried values + // from the SCF ops (which is mainly a region). So we need to "mark" types + // used by SCF ops as legal, if to use the conversion framework for type + // conversion. There isn't a straightforward way to do that yet, as when + // converting types, ops aren't taken into consideration. Therefore, we just + // bypass the framework's type conversion for now. + SPIRVTypeConverter &typeConverter; }; /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. @@ -90,7 +115,6 @@ /// we load the value from the allocation and use it as the SCF op result. template static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, - SPIRVTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, ScfToSPIRVContextImpl *scfToSPIRVContext, ArrayRef returnTypes) { @@ -117,7 +141,7 @@ } //===----------------------------------------------------------------------===// -// scf::ForOp. +// scf::ForOp //===----------------------------------------------------------------------===// LogicalResult @@ -196,13 +220,12 @@ SmallVector initTypes; for (auto arg : forOperands.initArgs()) initTypes.push_back(arg.getType()); - replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, - scfToSPIRVContext, initTypes); + replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); return success(); } //===----------------------------------------------------------------------===// -// scf::IfOp. +// scf::IfOp //===----------------------------------------------------------------------===// LogicalResult @@ -255,11 +278,15 @@ auto convertedType = typeConverter.convertType(result.getType()); returnTypes.push_back(convertedType); } - replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, - scfToSPIRVContext, returnTypes); + replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, + returnTypes); return success(); } +//===----------------------------------------------------------------------===// +// scf::YieldOp +//===----------------------------------------------------------------------===// + /// Yield is lowered to stores to the VariableOp created during lowering of the /// parent region. For loops we also need to update the branch looping back to /// the header with the loop carried values. @@ -290,6 +317,10 @@ return success(); } +//===----------------------------------------------------------------------===// +// Hooks +//===----------------------------------------------------------------------===// + void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -159,7 +159,7 @@ /// 1D array (spv.array or spv.rt_array), the last index is modified to load the /// bits needed. The extraction of the actual bits needed are handled /// separately. Note that this only works for a 1-D tensor. -static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, +static Value adjustAccessChainForBitwidth(TypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder) { @@ -264,9 +264,9 @@ /// to Workgroup memory when the size is constant. Note that this pattern needs /// to be applied in a pass that runs at least at spv.module scope since it wil /// ladd global variables into the spv.module. -class AllocOpPattern final : public SPIRVOpLowering { +class AllocOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AllocOp operation, ArrayRef operands, @@ -276,7 +276,7 @@ return operation.emitError("unhandled allocation type"); // Get the SPIR-V type for the allocation. - Type spirvType = typeConverter.convertType(allocType); + Type spirvType = getTypeConverter()->convertType(allocType); // Insert spv.globalVariable for this allocation. Operation *parent = @@ -306,9 +306,9 @@ /// Removed a deallocation if it is a supported allocation. Currently only /// removes deallocation if the memory space is workgroup memory. -class DeallocOpPattern final : public SPIRVOpLowering { +class DeallocOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(DeallocOp operation, ArrayRef operands, @@ -323,15 +323,15 @@ /// Converts unary and binary standard operations to SPIR-V operations. template -class UnaryAndBinaryOpPattern final : public SPIRVOpLowering { +class UnaryAndBinaryOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() <= 2); - auto dstType = this->typeConverter.convertType(operation.getType()); + auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); if (isUnsignedOp() && dstType != operation.getType()) { @@ -347,9 +347,9 @@ /// /// This cannot be merged into the template unary/binary pattern due to /// Vulkan restrictions over spv.SRem and spv.SMod. -class SignedRemIOpPattern final : public SPIRVOpLowering { +class SignedRemIOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SignedRemIOp remOp, ArrayRef operands, @@ -361,16 +361,16 @@ /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. template -class BitwiseOpPattern final : public SPIRVOpLowering { +class BitwiseOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 2); auto dstType = - this->typeConverter.convertType(operation.getResult().getType()); + this->getTypeConverter()->convertType(operation.getResult().getType()); if (!dstType) return failure(); if (isBoolScalarOrVector(operands.front().getType())) { @@ -385,9 +385,10 @@ }; /// Converts composite std.constant operation to spv.constant. -class ConstantCompositeOpPattern final : public SPIRVOpLowering { +class ConstantCompositeOpPattern final + : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, @@ -395,9 +396,9 @@ }; /// Converts scalar std.constant operation to spv.constant. -class ConstantScalarOpPattern final : public SPIRVOpLowering { +class ConstantScalarOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, @@ -405,9 +406,9 @@ }; /// Converts floating-point comparison operations to SPIR-V ops. -class CmpFOpPattern final : public SPIRVOpLowering { +class CmpFOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, @@ -415,9 +416,9 @@ }; /// Converts integer compare operation on i1 type operands to SPIR-V ops. -class BoolCmpIOpPattern final : public SPIRVOpLowering { +class BoolCmpIOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, @@ -425,9 +426,9 @@ }; /// Converts integer compare operation to SPIR-V ops. -class CmpIOpPattern final : public SPIRVOpLowering { +class CmpIOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, @@ -435,9 +436,9 @@ }; /// Converts std.load to spv.Load. -class IntLoadOpPattern final : public SPIRVOpLowering { +class IntLoadOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, @@ -445,9 +446,9 @@ }; /// Converts std.load to spv.Load. -class LoadOpPattern final : public SPIRVOpLowering { +class LoadOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, @@ -455,9 +456,9 @@ }; /// Converts std.return to spv.Return. -class ReturnOpPattern final : public SPIRVOpLowering { +class ReturnOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, @@ -465,18 +466,18 @@ }; /// Converts std.select to spv.Select. -class SelectOpPattern final : public SPIRVOpLowering { +class SelectOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.store to spv.Store on integers. -class IntStoreOpPattern final : public SPIRVOpLowering { +class IntStoreOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, @@ -484,9 +485,9 @@ }; /// Converts std.store to spv.Store. -class StoreOpPattern final : public SPIRVOpLowering { +class StoreOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, @@ -495,9 +496,9 @@ /// Converts std.zexti to spv.Select if the type of source is i1 or vector of /// i1. -class ZeroExtendI1Pattern final : public SPIRVOpLowering { +class ZeroExtendI1Pattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, @@ -506,7 +507,8 @@ if (!isBoolScalarOrVector(srcType)) return failure(); - auto dstType = this->typeConverter.convertType(op.getResult().getType()); + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Attribute zeroAttr, oneAttr; if (auto vectorType = dstType.dyn_cast()) { @@ -526,9 +528,9 @@ /// Converts type-casting standard operations to SPIR-V operations. template -class TypeCastingOpPattern final : public SPIRVOpLowering { +class TypeCastingOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, @@ -538,7 +540,7 @@ if (isBoolScalarOrVector(srcType)) return failure(); auto dstType = - this->typeConverter.convertType(operation.getResult().getType()); + this->getTypeConverter()->convertType(operation.getResult().getType()); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. @@ -552,9 +554,9 @@ }; /// Converts std.xor to SPIR-V operations. -class XOrOpPattern final : public SPIRVOpLowering { +class XOrOpPattern final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(XOrOp xorOp, ArrayRef operands, @@ -591,7 +593,7 @@ // std.constant should only have vector or tenor types. assert((srcType.isa())); - auto dstType = typeConverter.convertType(srcType); + auto dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); @@ -674,7 +676,7 @@ if (!srcType.isIntOrIndexOrFloat()) return failure(); - Type dstType = typeConverter.convertType(srcType); + Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); @@ -800,7 +802,7 @@ #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (isUnsignedOp() && \ - operandType != this->typeConverter.convertType(operandType)) { \ + operandType != this->getTypeConverter()->convertType(operandType)) { \ return cmpIOp.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ @@ -837,12 +839,13 @@ auto memrefType = loadOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); - spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loc, rewriter); + spirv::AccessChainOp accessChainOp = spirv::getElementPtr( + *getTypeConverter(), memrefType, + loadOperands.memref(), loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - auto dstType = typeConverter.convertType(memrefType) + auto dstType = getTypeConverter() + ->convertType(memrefType) .cast() .getPointeeType() .cast() @@ -864,8 +867,8 @@ // still returns a linearized accessing. If the accessing is not linearized, // there will be offset issues. assert(accessChainOp.indices().size() == 2); - Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, - srcBits, dstBits, rewriter); + Value adjustedPtr = adjustAccessChainForBitwidth( + *getTypeConverter(), accessChainOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( loc, dstType, adjustedPtr, loadOp->getAttrOfType( @@ -910,9 +913,9 @@ auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); - auto loadPtr = - spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loadOp.getLoc(), rewriter); + auto loadPtr = spirv::getElementPtr( + *getTypeConverter(), memrefType, + loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } @@ -962,11 +965,12 @@ return failure(); auto loc = storeOp.getLoc(); + auto *typeConverter = getTypeConverter(); spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), + spirv::getElementPtr(*typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - auto dstType = typeConverter.convertType(memrefType) + auto dstType = typeConverter->convertType(memrefType) .cast() .getPointeeType() .cast() @@ -1007,8 +1011,8 @@ Value storeVal = shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); - Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, - srcBits, dstBits, rewriter); + Value adjustedPtr = adjustAccessChainForBitwidth( + *getTypeConverter(), accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); if (!scope) return failure(); @@ -1039,8 +1043,9 @@ if (memrefType.getElementType().isSignlessInteger()) return failure(); auto storePtr = - spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), - storeOperands.indices(), storeOp.getLoc(), rewriter); + spirv::getElementPtr(*getTypeConverter(), memrefType, + storeOperands.memref(), storeOperands.indices(), + storeOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return success(); @@ -1058,7 +1063,7 @@ if (isBoolScalarOrVector(operands.front().getType())) return failure(); - auto dstType = typeConverter.convertType(xorOp.getType()); + auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, operands); @@ -1125,7 +1130,7 @@ TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern>(context, - typeConverter); + TypeCastingOpPattern>(typeConverter, + context); } } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -24,8 +24,9 @@ namespace { struct VectorBroadcastConvert final - : public SPIRVOpLowering { - using SPIRVOpLowering::SPIRVOpLowering; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -43,8 +44,9 @@ }; struct VectorExtractOpConvert final - : public SPIRVOpLowering { - using SPIRVOpLowering::SPIRVOpLowering; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -60,8 +62,10 @@ } }; -struct VectorInsertOpConvert final : public SPIRVOpLowering { - using SPIRVOpLowering::SPIRVOpLowering; +struct VectorInsertOpConvert final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -78,8 +82,9 @@ }; struct VectorExtractElementOpConvert final - : public SPIRVOpLowering { - using SPIRVOpLowering::SPIRVOpLowering; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(vector::ExtractElementOp extractElementOp, ArrayRef operands, @@ -96,8 +101,9 @@ }; struct VectorInsertElementOpConvert final - : public SPIRVOpLowering { - using SPIRVOpLowering::SPIRVOpLowering; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(vector::InsertElementOp insertElementOp, ArrayRef operands, @@ -120,5 +126,5 @@ OwningRewritePatternList &patterns) { patterns.insert(context, typeConverter); + VectorInsertElementOpConvert>(typeConverter, context); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -151,9 +151,10 @@ /// variable ABI attributes attached to function arguments and converts all /// function argument uses to those global variables. This is necessary because /// Vulkan requires all shader entry points to be of void(void) type. -class ProcessInterfaceVarABI final : public SPIRVOpLowering { +class ProcessInterfaceVarABI final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; @@ -214,7 +215,7 @@ } signatureConverter.remapInput(argType.index(), replacement); } - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter, + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(), &signatureConverter))) return failure(); @@ -246,7 +247,7 @@ }); OwningRewritePatternList patterns; - patterns.insert(context, typeConverter); + patterns.insert(typeConverter, context); ConversionTarget target(*context); // "Legal" function ops should have no interface variable ABI attributes. diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" @@ -459,9 +460,9 @@ namespace { /// A pattern for rewriting function signature to convert arguments of functions /// to be of valid SPIR-V types. -class FuncOpConversion final : public SPIRVOpLowering { +class FuncOpConversion final : public OpConversionPattern { public: - using SPIRVOpLowering::SPIRVOpLowering; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, @@ -478,7 +479,7 @@ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); for (auto argType : enumerate(fnType.getInputs())) { - auto convertedType = typeConverter.convertType(argType.value()); + auto convertedType = getTypeConverter()->convertType(argType.value()); if (!convertedType) return failure(); signatureConverter.addInputs(argType.index(), convertedType); @@ -486,7 +487,7 @@ Type resultType; if (fnType.getNumResults() == 1) - resultType = typeConverter.convertType(fnType.getResult(0)); + resultType = getTypeConverter()->convertType(fnType.getResult(0)); // Create the converted spv.func op. auto newFuncOp = rewriter.create( @@ -504,8 +505,8 @@ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, - &signatureConverter))) + if (failed(rewriter.convertRegionTypes( + &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) return failure(); rewriter.eraseOp(funcOp); return success(); @@ -514,7 +515,7 @@ void mlir::populateBuiltinFuncToSPIRVPatterns( MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert(context, typeConverter); + patterns.insert(typeConverter, context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s //===----------------------------------------------------------------------===// // std arithmetic ops @@ -628,49 +628,59 @@ // ----- -// Checks that cast types will be adjusted when no special capabilities for -// non-32-bit scalar types. +// Checks that cast types will be adjusted when missing special capabilities for +// certain non-32-bit scalar types. module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> + spv.target_env = #spv.target_env<#spv.vce, {}> } { // CHECK-LABEL: @fpext1 // CHECK-SAME: %[[ARG:.*]]: f32 -func @fpext1(%arg0: f16) { - // CHECK-NEXT: "use"(%[[ARG]]) +func @fpext1(%arg0: f16) -> f64 { + // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64 %0 = std.fpext %arg0 : f16 to f64 - "use"(%0) : (f64) -> () + return %0: f64 } // CHECK-LABEL: @fpext2 // CHECK-SAME: %[[ARG:.*]]: f32 -func @fpext2(%arg0 : f32) { - // CHECK-NEXT: "use"(%[[ARG]]) +func @fpext2(%arg0 : f32) -> f64 { + // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64 %0 = std.fpext %arg0 : f32 to f64 - "use"(%0) : (f64) -> () + return %0: f64 } +} // end module + +// ----- + +// Checks that cast types will be adjusted when missing special capabilities for +// certain non-32-bit scalar types. +module attributes { + spv.target_env = #spv.target_env<#spv.vce, {}> +} { + // CHECK-LABEL: @fptrunc1 // CHECK-SAME: %[[ARG:.*]]: f32 -func @fptrunc1(%arg0 : f64) { - // CHECK-NEXT: "use"(%[[ARG]]) +func @fptrunc1(%arg0 : f64) -> f16 { + // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16 %0 = std.fptrunc %arg0 : f64 to f16 - "use"(%0) : (f16) -> () + return %0: f16 } // CHECK-LABEL: @fptrunc2 // CHECK-SAME: %[[ARG:.*]]: f32 -func @fptrunc2(%arg0: f32) { - // CHECK-NEXT: "use"(%[[ARG]]) +func @fptrunc2(%arg0: f32) -> f16 { + // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16 %0 = std.fptrunc %arg0 : f32 to f16 - "use"(%0) : (f16) -> () + return %0: f16 } // CHECK-LABEL: @sitofp -func @sitofp(%arg0 : i64) { +func @sitofp(%arg0 : i64) -> f64 { // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32 %0 = std.sitofp %arg0 : i64 to f64 - "use"(%0) : (f64) -> () + return %0: f64 } } // end module