diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -37,6 +37,12 @@
                                            LLVMTypeConverter &typeConverter,
                                            OwningRewritePatternList &patterns);
 
+/// Populates the given list with patterns for function conversion from SPIR-V
+/// to LLVM.
+void populateSPIRVToLLVMFunctionConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -150,6 +150,34 @@
   }
 };
 
+/// Converts `spv.Return` to an `llvm.return` with no operands.
+class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
+                                                ArrayRef<Value>());
+    return success();
+  }
+};
+
+/// Converts `spv.ReturnValue` to an `llvm.return` forwarding its operands.
+class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
+                                                operands);
+    return success();
+  }
+};
+
 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
 /// puts a restriction on `Shift` and `Base` to have the same bit width,
 /// `Shift` is zero or sign extended to match this specification. Cases when
@@ -191,6 +219,71 @@
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// FuncOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `spv.func` to `llvm.func`. The signature goes through the LLVM
+/// type converter, the function control becomes a `passthrough` attribute,
+/// and the body region is moved into the new function.
+class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Convert function signature. At the moment LLVMType converter is enough
+    // for currently supported types.
+    auto funcType = funcOp.getType();
+    TypeConverter::SignatureConversion signatureConverter(
+        funcType.getNumInputs());
+    auto llvmType = this->typeConverter.convertFunctionSignature(
+        funcType, /*isVariadic=*/false, signatureConverter);
+    // A null type means some argument or result type could not be converted;
+    // fail the match before any IR is created instead of building a bad op.
+    if (!llvmType)
+      return failure();
+
+    // Create a new `LLVMFuncOp`
+    Location loc = funcOp.getLoc();
+    StringRef name = funcOp.getName();
+    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
+
+    // Convert SPIR-V Function Control to equivalent LLVM function attribute
+    MLIRContext *context = funcOp.getContext();
+    switch (funcOp.function_control()) {
+#define DISPATCH(functionControl, llvmAttr)                                    \
+  case functionControl:                                                        \
+    newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));     \
+    break;
+
+      DISPATCH(spirv::FunctionControl::Inline,
+               StringAttr::get("alwaysinline", context));
+      DISPATCH(spirv::FunctionControl::DontInline,
+               StringAttr::get("noinline", context));
+      DISPATCH(spirv::FunctionControl::Pure,
+               StringAttr::get("readonly", context));
+      DISPATCH(spirv::FunctionControl::Const,
+               StringAttr::get("readnone", context));
+
+#undef DISPATCH
+
+    // Default: `spirv::FunctionControl::None`, or any value not matched
+    // exactly by a case above, adds no attributes.
+    default:
+      break;
+    }
+
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -263,6 +356,14 @@
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
-      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
-                                                            typeConverter);
+      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
+
+      // Return ops
+      ReturnPattern, ReturnValuePattern>(context, typeConverter);
+}
+
+void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<FuncConversionPattern>(context, typeConverter);
 }
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -35,6 +35,7 @@
 
   OwningRewritePatternList patterns;
   populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
+  populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
 
   // Currently pulls in Std to LLVM conversion patterns
   // that help with testing. This allows to convert
diff --git a/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+func @return() {
+  // CHECK: llvm.return
+  spv.Return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+func @return_value(%arg: i32) {
+  // CHECK: llvm.return %{{.*}} : !llvm.i32
+  spv.ReturnValue %arg : i32
+}
+
+//===----------------------------------------------------------------------===//
+// spv.func
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: llvm.func @none()
+spv.func @none() -> () "None" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @inline() attributes {passthrough = ["alwaysinline"]}
+spv.func @inline() -> () "Inline" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @dont_inline() attributes {passthrough = ["noinline"]}
+spv.func @dont_inline() -> () "DontInline" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @pure() attributes {passthrough = ["readonly"]}
+spv.func @pure() -> () "Pure" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @const() attributes {passthrough = ["readnone"]}
+spv.func @const() -> () "Const" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @scalar_types(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.double, %arg3: !llvm.float)
+spv.func @scalar_types(%arg0: i32, %arg1: i1, %arg2: f64, %arg3: f32) -> () "None" {
+  spv.Return
+}
+
+// CHECK-LABEL: llvm.func @vector_types(%arg0: !llvm<"<2 x i64>">, %arg1: !llvm<"<2 x i64>">) -> !llvm<"<2 x i64>">
+spv.func @vector_types(%arg0: vector<2xi64>, %arg1: vector<2xi64>) -> vector<2xi64> "None" {
+  %0 = spv.IAdd %arg0, %arg1 : vector<2xi64>
+  spv.ReturnValue %0 : vector<2xi64>
+}