diff --git a/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h b/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h
@@ -0,0 +1,28 @@
+//===- ComplexToSPIRV.h - Complex to SPIR-V Patterns ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns to convert Complex dialect to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRV_H
+#define MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+
+/// Adds to `patterns` the conversion patterns that translate Complex dialect
+/// ops to SPIR-V dialect ops, constructed with `typeConverter`.
+void populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                    RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h b/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h
@@ -0,0 +1,26 @@
+//===- ComplexToSPIRVPass.h - Complex to SPIR-V Passes ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides passes to convert Complex dialect to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRVPASS_H
+#define MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class ModuleOp;
+// Generated pass declarations for ConvertComplexToSPIRVPass (see Passes.td).
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOSPIRV_COMPLEXTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -18,6 +18,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -222,6 +222,15 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToSPIRVPass : Pass<"convert-complex-to-spirv"> {
+  let summary = "Convert Complex dialect to SPIRV dialect";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
 add_subdirectory(ControlFlowToLLVM)
 add_subdirectory(ControlFlowToSPIRV)
diff --git a/mlir/lib/Conversion/ComplexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ComplexToSPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToSPIRV/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRComplexToSPIRV
+  ComplexToSPIRV.cpp
+  ComplexToSPIRVPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRIR
+  MLIRPass
+  MLIRSPIRVDialect
+  MLIRSPIRVConversion
+  MLIRSupport
+  MLIRTransformUtils
+  )
+
diff --git a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
@@ -0,0 +1,93 @@
+//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Complex dialect to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "complex-to-spirv-pattern"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers complex.create to spirv.CompositeConstruct of the converted type.
+struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type spirvType = getTypeConverter()->convertType(createOp.getType());
+    if (!spirvType)
+      return rewriter.notifyMatchFailure(createOp,
+                                         "unable to convert result type");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        createOp, spirvType, adaptor.getOperands());
+    return success();
+  }
+};
+/// Lowers complex.re to spirv.CompositeExtract of element 0.
+struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type spirvType = getTypeConverter()->convertType(reOp.getType());
+    if (!spirvType)
+      return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        reOp, adaptor.getComplex(), llvm::ArrayRef(0));
+    return success();
+  }
+};
+/// Lowers complex.im to spirv.CompositeExtract of element 1.
+struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type spirvType = getTypeConverter()->convertType(imOp.getType());
+    if (!spirvType)
+      return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        imOp, adaptor.getComplex(), llvm::ArrayRef(1));
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                          RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  // Register one conversion pattern per supported Complex op.
+  patterns.add<CreateOpPattern, ReOpPattern, ImOpPattern>(typeConverter,
+                                                          context);
+}
diff --git a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
@@ -0,0 +1,59 @@
+//===- ComplexToSPIRVPass.cpp - Complex to SPIR-V Passes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert Complex dialect to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
+
+#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// A pass converting MLIR Complex operations into the SPIR-V dialect.
+class ConvertComplexToSPIRVPass
+    : public impl::ConvertComplexToSPIRVPassBase<ConvertComplexToSPIRVPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    Operation *op = getOperation();
+    // Look up the SPIR-V target environment for `op` (or a default one).
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+
+    SPIRVConversionOptions options;
+    SPIRVTypeConverter typeConverter(targetAttr, options);
+
+    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+    // in patterns for other dialects.
+    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) {
+      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+      return std::optional<Value>(cast.getResult(0));
+    };
+    typeConverter.addSourceMaterialization(addUnrealizedCast);
+    typeConverter.addTargetMaterialization(addUnrealizedCast);
+    target->addLegalOp<UnrealizedConversionCastOp>();
+    // Collect the Complex-to-SPIR-V patterns and apply a partial conversion.
+    RewritePatternSet patterns(context);
+    populateComplexToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir b/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToSPIRV/complex-to-spirv.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -split-input-file -convert-complex-to-spirv %s | FileCheck %s
+
+func.func @create_complex(%real: f32, %imag: f32) -> complex<f32> {
+  %0 = complex.create %real, %imag : complex<f32>
+  return %0 : complex<f32>
+}
+
+// CHECK-LABEL: func.func @create_complex
+//  CHECK-SAME: (%[[RE:.+]]: f32, %[[IM:.+]]: f32)
+//       CHECK:   %[[CC:.+]] = spirv.CompositeConstruct %[[RE]], %[[IM]] : (f32, f32) -> vector<2xf32>
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[CC]] : vector<2xf32> to complex<f32>
+//       CHECK:   return %[[CAST]] : complex<f32>
+
+
+// -----
+
+func.func @real_number(%arg: complex<f32>) -> f32 {
+  %real = complex.re %arg : complex<f32>
+  return %real : f32
+}
+
+// CHECK-LABEL: func.func @real_number
+//  CHECK-SAME: %[[ARG:.+]]: complex<f32>
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : complex<f32> to vector<2xf32>
+//       CHECK:   %[[RE:.+]] = spirv.CompositeExtract %[[CAST]][0 : i32] : vector<2xf32>
+//       CHECK:   return %[[RE]] : f32
+
+// -----
+
+func.func @imaginary_number(%arg: complex<f32>) -> f32 {
+  %imaginary = complex.im %arg : complex<f32>
+  return %imaginary: f32
+}
+
+// CHECK-LABEL: func.func @imaginary_number
+//  CHECK-SAME: %[[ARG:.+]]: complex<f32>
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : complex<f32> to vector<2xf32>
+//       CHECK:   %[[IM:.+]] = spirv.CompositeExtract %[[CAST]][1 : i32] : vector<2xf32>
+//       CHECK:   return %[[IM]] : f32
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2771,6 +2771,7 @@
         ":BufferizationToMemRef",
         ":ComplexToLLVM",
         ":ComplexToLibm",
+        ":ComplexToSPIRV",
         ":ComplexToStandard",
         ":ControlFlowToLLVM",
         ":ControlFlowToSPIRV",
@@ -7080,6 +7081,7 @@
         ":ComplexDialect",
         ":ComplexToLLVM",
         ":ComplexToLibm",
+        ":ComplexToSPIRV",
         ":ControlFlowDialect",
         ":ConversionPasses",
         ":DLTIDialect",
@@ -9593,6 +9595,31 @@
     ],
 )
 
+cc_library(
+    name = "ComplexToSPIRV",
+    srcs = glob([
+        "lib/Conversion/ComplexToSPIRV/*.cpp",
+        "lib/Conversion/ComplexToSPIRV/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/ComplexToSPIRV/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ComplexDialect",
+        ":ConversionPassIncGen",
+        ":IR",
+        ":Pass",
+        ":SPIRVCommonConversion",
+        ":SPIRVConversion",
+        ":SPIRVDialect",
+        ":Support",
+        ":Transforms",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "ComplexToStandard",
     srcs = glob([