diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -49,6 +49,14 @@
   SPIRVTypeConverter &typeConverter;
 };
 
+/// Appends to a pattern list additional patterns for translating the builtin
+/// `func` op to the SPIR-V dialect. These patterns do not handle shader
+/// interface/ABI; they only convert function parameters to SPIR-V-compatible
+/// types. Functions that return results are not converted by these patterns.
+void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
+                                        SPIRVTypeConverter &typeConverter,
+                                        OwningRewritePatternList &patterns);
+
 namespace spirv {
 class SPIRVConversionTarget : public ConversionTarget {
 public:
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -20,18 +20,6 @@
 using namespace mlir;
 
 namespace {
-
-/// A simple pattern for rewriting function signature to convert arguments of
-/// functions to be of valid SPIR-V types.
-class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
-public:
-  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
-
-  PatternMatchResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
 /// A pass converting MLIR Standard operations into the SPIR-V dialect.
 class ConvertStandardToSPIRVPass
     : public ModulePass<ConvertStandardToSPIRVPass> {
@@ -39,33 +27,6 @@
 };
 } // namespace
 
-PatternMatchResult
-FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
-  auto fnType = funcOp.getType();
-  if (fnType.getNumResults()) {
-    return matchFailure();
-  }
-
-  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  {
-    for (auto argType : enumerate(funcOp.getType().getInputs())) {
-      auto convertedType = typeConverter.convertType(argType.value());
-      if (!convertedType) {
-        return matchFailure();
-      }
-      signatureConverter.addInputs(argType.index(), convertedType);
-    }
-  }
-
-  rewriter.updateRootInPlace(funcOp, [&] {
-    funcOp.setType(rewriter.getFunctionType(
-        signatureConverter.getConvertedTypes(), llvm::None));
-    rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
-  });
-  return matchSuccess();
-}
-
 void ConvertStandardToSPIRVPass::runOnModule() {
   MLIRContext *context = &getContext();
   ModuleOp module = getModule();
@@ -73,7 +34,7 @@
   SPIRVTypeConverter typeConverter;
   OwningRewritePatternList patterns;
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
-  patterns.insert<FuncOpConversion>(context, typeConverter);
+  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
 
   std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
       spirv::lookupTargetEnvOrDefault(module), context);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -171,6 +171,55 @@
 Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }
 
 //===----------------------------------------------------------------------===//
+// FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert arguments of functions
+/// to be of valid SPIR-V types.
+class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
+public:
+  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+PatternMatchResult
+FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  auto fnType = funcOp.getType();
+  // The converted signature below is built with no results, so only
+  // functions without results are handled.
+  if (fnType.getNumResults())
+    return matchFailure();
+
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  for (auto argType : enumerate(fnType.getInputs())) {
+    auto convertedType = typeConverter.convertType(argType.value());
+    if (!convertedType)
+      return matchFailure();
+    signatureConverter.addInputs(argType.index(), convertedType);
+  }
+
+  rewriter.updateRootInPlace(funcOp, [&] {
+    funcOp.setType(rewriter.getFunctionType(
+        signatureConverter.getConvertedTypes(), llvm::None));
+    rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+  });
+
+  return matchSuccess();
+}
+
+void mlir::populateBuiltinFuncToSPIRVPatterns(
+    MLIRContext *context, SPIRVTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<FuncOpConversion>(context, typeConverter);
+}
+
+//===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//
 