diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -207,6 +207,15 @@ } //===----------------------------------------------------------------------===// +// SPIRVToLLVM +//===----------------------------------------------------------------------===// + +def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> { + let summary = "Convert SPIR-V dialect to LLVM dialect"; + let constructor = "mlir::createConvertSPIRVToLLVMPass()"; +} + +//===----------------------------------------------------------------------===// // StandardToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h @@ -0,0 +1,30 @@ +//===- ConvertSPIRVToLLVM.h - Convert SPIR-V to LLVM dialect ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert SPIR-V dialect to LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H +#define MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class LLVMTypeConverter; +class MLIRContext; +class ModuleOp; + +/// Populates the given list with patterns that convert from SPIR-V to LLVM. 
+void populateSPIRVToLLVMConversionPatterns(MLIRContext *context, + LLVMTypeConverter &typeConverter, + OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h @@ -0,0 +1,28 @@ +//===- ConvertSPIRVToLLVMPass.h - SPIR-V dialect to LLVM pass ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a pass to lower from SPIR-V dialect to LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H +#define MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H + +#include <memory> + +namespace mlir { +class ModuleOp; +template <typename T> +class OperationPass; + +/// Creates a pass to convert SPIR-V operations to the LLVMIR dialect. +std::unique_ptr<OperationPass<ModuleOp>> createConvertSPIRVToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -13,7 +13,8 @@ namespace mlir { class LLVMTypeConverter; class ModuleOp; -template <typename T> class OperationPass; +template <typename T> +class OperationPass; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix /// Intrinsics. 
To lower to assembly, the LLVM flag -lower-matrix-intrinsics diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -25,6 +25,7 @@ #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(SCFToGPU) add_subdirectory(SCFToStandard) add_subdirectory(ShapeToStandard) +add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRSPIRVToLLVM + ConvertSPIRVToLLVM.cpp + ConvertSPIRVToLLVMPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SPIRVToLLVM + + DEPENDS + MLIRConversionPassIncGen + intrinsics_gen + + LINK_LIBS PUBLIC + MLIRSPIRV + MLIRLLVMIR + MLIRStandardToLLVM + MLIRIR + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -0,0 +1,53 @@ +//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert SPIR-V dialect to LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +class BitwiseAndOpConversion : public ConvertToLLVMPattern { +public: + explicit BitwiseAndOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op); + auto dstType = typeConverter.convertType(bitwiseAndOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands); + return success(); + } +}; +} // namespace + +void mlir::populateSPIRVToLLVMConversionPatterns( + MLIRContext *context, LLVMTypeConverter &typeConverter, + OwningRewritePatternList &patterns) { + patterns.insert<BitwiseAndOpConversion>(context, typeConverter); +} diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp @@ -0,0 +1,54 @@ +//===- ConvertSPIRVToLLVMPass.cpp - Convert SPIR-V ops to LLVM ops --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR SPIR-V ops into LLVM ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" +#include "../PassDetail.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" + +using namespace mlir; + +namespace { +/// A pass converting MLIR SPIR-V operations into LLVM dialect. +class ConvertSPIRVToLLVMPass + : public ConvertSPIRVToLLVMBase<ConvertSPIRVToLLVMPass> { + void runOnOperation() override; +}; +} // namespace + +void ConvertSPIRVToLLVMPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + LLVMTypeConverter converter(&getContext()); + + OwningRewritePatternList patterns; + populateSPIRVToLLVMConversionPatterns(context, converter, patterns); + + // Currently pulls in Std to LLVM conversion patterns + // that help with testing. This allows converting + // function arguments to LLVM. 
+ populateStdToLLVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addIllegalDialect<spirv::SPIRVDialect>(); + target.addLegalDialect<LLVM::LLVMDialect>(); + + if (failed(applyPartialConversion(module, target, patterns, &converter))) + signalPassFailure(); +} + +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSPIRVToLLVMPass() { + return std::make_unique<ConvertSPIRVToLLVMPass>(); +} diff --git a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s + +func @bitwise_and_scalar(%arg0: i32, %arg1: i32) { + // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32 + %0 = spv.BitwiseAnd %arg0, %arg1 : i32 + return +} + +func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) { + // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>"> + %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64> + return +}