diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h @@ -0,0 +1,28 @@ +//===- MathToSPIRV.h - Math to SPIR-V Patterns ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert Math dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H +#define MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; + +/// Appends to `patterns` the patterns that convert Math dialect ops to SPIR-V +/// ops, using `typeConverter` for type conversion. +void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h @@ -0,0 +1,25 @@ +//===- MathToSPIRVPass.h - Math to SPIR-V Passes ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides passes to convert Math dialect to SPIR-V dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H +#define MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Creates a module pass that converts Math dialect ops to SPIR-V ops. +std::unique_ptr> createConvertMathToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -24,6 +24,7 @@ #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -268,6 +268,16 @@ let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// MathToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertMathToSPIRV : Pass<"convert-math-to-spirv", "ModuleOp"> { + let summary = "Convert Math dialect to SPIR-V dialect"; + let constructor = "mlir::createConvertMathToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + //===----------------------------------------------------------------------===// // MemRefToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- 
a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(LLVMCommon) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) +add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToLLVM) add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToSCF) diff --git a/mlir/lib/Conversion/MathToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/MathToSPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MathToSPIRV/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRMathToSPIRV + MathToSPIRV.cpp + MathToSPIRVPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMath + MLIRPass + MLIRSPIRV + MLIRSPIRVConversion + MLIRSupport + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -0,0 +1,106 @@ +//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert Math dialect to SPIR-V dialect.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "math-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +// Note that DRR cannot be used for the patterns in this file: we may need to +// convert type along the way, which requires ConversionPattern. DRR generates +// normal RewritePattern. + +namespace { + +/// Converts a unary or binary Math operation (MathOp) to the SPIR-V operation +/// SPIRVOp with the same operands and the converted result type. Fails to +/// match if the result type cannot be converted. +template <typename MathOp, typename SPIRVOp> +class UnaryAndBinaryOpPattern final : public OpConversionPattern<MathOp> { +public: + using OpConversionPattern<MathOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(MathOp operation, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() <= 2); + auto dstType = this->getTypeConverter()->convertType(operation.getType()); + if (!dstType) + return failure(); + if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && + dstType != operation.getType()) { + return operation.emitError( + "bitwidth emulation is not implemented yet on unsigned op"); + } + rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands); + return success(); + } +}; + +/// Converts math.log1p to SPIR-V ops. +/// +/// SPIR-V has no direct equivalent of log(1+x), so this pattern lowers it +/// explicitly to spv.FAdd(1.0, x) followed by spv.GLSL.Log. Computing 1 + x +/// first loses precision when |x| is much smaller than 1.
+class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { +public: + using OpConversionPattern<math::Log1pOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1); + Location loc = operation.getLoc(); + Type type = getTypeConverter()->convertType(operation.operand().getType()); + // Fail the match (as UnaryAndBinaryOpPattern does) instead of building ops + // with a null type when the operand type cannot be converted. + if (!type) + return failure(); + auto one = spirv::ConstantOp::getOne(type, loc, rewriter); + auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]); + rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +namespace mlir { +void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add<Log1pOpPattern, + UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>, + UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>, + UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>, + UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, + UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>, + UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>, + UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, + UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( + typeConverter, patterns.getContext()); +} + +} // namespace mlir diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp @@ -0,0 +1,48 @@ +//===- MathToSPIRVPass.cpp - Math to SPIR-V Passes ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert Math dialect to SPIR-V dialect.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" +#include "../PassDetail.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" + +using namespace mlir; + +namespace { +/// A module pass that converts Math dialect operations to the SPIR-V dialect. +class ConvertMathToSPIRVPass + : public ConvertMathToSPIRVBase<ConvertMathToSPIRVPass> { + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr<ConversionTarget> target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); + + RewritePatternSet patterns(context); + populateMathToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToSPIRVPass() { + return std::make_unique<ConvertMathToSPIRVPass>(); +} diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -317,28 +316,6 @@ } }; -/// Converts math.log1p to SPIR-V ops. -/// -/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to -/// these operations.
-class Log1pOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(math::Log1pOp operation, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1); - Location loc = operation.getLoc(); - auto type = - this->getTypeConverter()->convertType(operation.operand().getType()); - auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); - auto onePlus = rewriter.create(loc, one, operands[0]); - rewriter.replaceOpWithNewOp(operation, type, onePlus); - return success(); - } -}; - /// Converts std.remi_signed to SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to @@ -1336,17 +1313,6 @@ MLIRContext *context = patterns.getContext(); patterns.add< - // Math dialect operations. - // TODO: Move to separate pass. - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - // Unary and binary patterns BitwiseOpPattern, BitwiseOpPattern, @@ -1369,7 +1335,7 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - Log1pOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, + SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s + +// CHECK-LABEL: @float32_unary_scalar +func @float32_unary_scalar(%arg0: f32) { + // CHECK: spv.GLSL.Cos %{{.*}}: f32 + %0 = math.cos %arg0 : f32 + // CHECK:
spv.GLSL.Exp %{{.*}}: f32 + %1 = math.exp %arg0 : f32 + // CHECK: spv.GLSL.Log %{{.*}}: f32 + %2 = math.log %arg0 : f32 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 + // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} + // CHECK: spv.GLSL.Log %[[ADDONE]] + %3 = math.log1p %arg0 : f32 + // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 + %4 = math.rsqrt %arg0 : f32 + // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 + %5 = math.sqrt %arg0 : f32 + // CHECK: spv.GLSL.Tanh %{{.*}}: f32 + %6 = math.tanh %arg0 : f32 + // CHECK: spv.GLSL.Sin %{{.*}}: f32 + %7 = math.sin %arg0 : f32 + return +} + +// CHECK-LABEL: @float32_unary_vector +func @float32_unary_vector(%arg0: vector<3xf32>) { + // CHECK: spv.GLSL.Cos %{{.*}}: vector<3xf32> + %0 = math.cos %arg0 : vector<3xf32> + // CHECK: spv.GLSL.Exp %{{.*}}: vector<3xf32> + %1 = math.exp %arg0 : vector<3xf32> + // CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32> + %2 = math.log %arg0 : vector<3xf32> + // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> + // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} + // CHECK: spv.GLSL.Log %[[ADDONE]] + %3 = math.log1p %arg0 : vector<3xf32> + // CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32> + %4 = math.rsqrt %arg0 : vector<3xf32> + // CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32> + %5 = math.sqrt %arg0 : vector<3xf32> + // CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32> + %6 = math.tanh %arg0 : vector<3xf32> + // CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32> + %7 = math.sin %arg0 : vector<3xf32> + return +} + +// CHECK-LABEL: @float32_binary_scalar +func @float32_binary_scalar(%lhs: f32, %rhs: f32) { + // CHECK: spv.GLSL.Pow %{{.*}}: f32 + %0 = math.powf %lhs, %rhs : f32 + return +} + +// CHECK-LABEL: @float32_binary_vector +func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) { + // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32> + %0 = math.powf %lhs, %rhs : vector<4xf32> + return +} diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -47,26 +47,8 @@ %0 = absf %arg0 : f32 // CHECK: spv.GLSL.Ceil %{{.*}}: f32 %1 = ceilf %arg0 : f32 - // CHECK: spv.GLSL.Cos %{{.*}}: f32 - %2 = math.cos %arg0 : f32 - // CHECK: spv.GLSL.Exp %{{.*}}: f32 - %3 = math.exp %arg0 : f32 - // CHECK: spv.GLSL.Log %{{.*}}: f32 - %4 = math.log %arg0 : f32 - // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 - // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} - // CHECK: spv.GLSL.Log %[[ADDONE]] - %40 = math.log1p %arg0 : f32 // CHECK: spv.FNegate %{{.*}}: f32 %5 = negf %arg0 : f32 - // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 - %6 = math.rsqrt %arg0 : f32 - // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 - %7 = math.sqrt %arg0 : f32 - // CHECK: spv.GLSL.Tanh %{{.*}}: f32 - %8 = math.tanh %arg0 : f32 - // CHECK: spv.GLSL.Sin %{{.*}}: f32 - %9 = math.sin %arg0 : f32 // CHECK: spv.GLSL.Floor %{{.*}}: f32 %10 = floorf %arg0 : f32 return @@ -85,8 +67,6 @@ %3 = divf %lhs, %rhs: f32 // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 %4 = remf %lhs, %rhs: f32 - // CHECK: spv.GLSL.Pow %{{.*}}: f32 - %5 = math.powf %lhs, %rhs : f32 return } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1916,6 +1916,7 @@ ":LinalgToStandard", ":MathToLLVM", ":MathToLibm", + ":MathToSPIRV", ":MemRefToLLVM", ":OpenACCToLLVM", ":OpenACCToSCF", @@ -3643,6 +3644,32 @@ ], ) +cc_library( + name = "MathToSPIRV", + srcs = glob([ + "lib/Conversion/MathToSPIRV/*.cpp", + "lib/Conversion/MathToSPIRV/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/MathToSPIRV/*.h", + ]), + includes = [ + "include", + "lib/Conversion/MathToSPIRV", + ], + deps = [ +
":ConversionPassIncGen", + ":IR", + ":MathDialect", + ":Pass", + ":SPIRVConversion", + ":SPIRVDialect", + ":Support", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "StandardToSPIRV", srcs = glob([ @@ -3659,7 +3686,6 @@ deps = [ ":ConversionPassIncGen", ":IR", - ":MathDialect", ":MemRefDialect", ":Pass", ":SPIRVConversion", @@ -4952,6 +4978,7 @@ ":MathDialect", ":MathToLLVM", ":MathToLibm", + ":MathToSPIRV", ":MathTransforms", ":MemRefDialect", ":MemRefToLLVM",