diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -0,0 +1,26 @@ +//===- MathToLLVM.h - Math to LLVM dialect conversion -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H +#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H + +#include + +namespace mlir { + +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertMathToLLVMPass(); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -22,6 +22,7 @@ #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -255,6 +255,19 @@ let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"]; } +//===----------------------------------------------------------------------===// 
+// MathToLLVM +//===----------------------------------------------------------------------===// + +def ConvertMathToLLVM : FunctionPass<"convert-math-to-llvm"> { + let summary = "Convert Math dialect to LLVM dialect"; + let description = [{ + This pass converts supported Math ops to LLVM dialect intrinsics. + }]; + let constructor = "mlir::createConvertMathToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + //===----------------------------------------------------------------------===// // MemRefToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) add_subdirectory(MathToLibm) +add_subdirectory(MathToLLVM) add_subdirectory(MemRefToLLVM) add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToSCF) diff --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRMathToLLVM + MathToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIRLLVMIR + MLIRMath + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -0,0 +1,234 @@ +//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +namespace { +using CosOpLowering = VectorConvertToLLVMPattern; +using ExpOpLowering = VectorConvertToLLVMPattern; +using Exp2OpLowering = VectorConvertToLLVMPattern; +using Log10OpLowering = + VectorConvertToLLVMPattern; +using Log2OpLowering = VectorConvertToLLVMPattern; +using LogOpLowering = VectorConvertToLLVMPattern; +using PowFOpLowering = VectorConvertToLLVMPattern; +using SinOpLowering = VectorConvertToLLVMPattern; +using SqrtOpLowering = VectorConvertToLLVMPattern; + +// A `expm1` is converted into `exp - 1`. 
+struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
+  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `op` as an LLVM `exp` followed by an `fsub` of a constant 1.0.
+  ///
+  /// Operands whose converted type is not an LLVM array (scalars and 1-D
+  /// vectors) are rewritten in place. Otherwise the rewrite is delegated to
+  /// `handleMultidimensionalVectors`, which is handed a callback that builds
+  /// the same three ops for one 1-D vector type.
+  ///
+  /// NOTE(review): `exp(x) - 1` is less accurate than a dedicated expm1 for
+  /// `x` close to zero; confirm this is acceptable for users of the pattern.
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::ExpM1Op::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    // Report why the pattern does not apply instead of failing silently.
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      // The constant must have the operand's type: a splat for vectors, a
+      // plain float attribute for scalars.
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      return success();
+    }
+
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Splat of 1.0 with as many elements as the 1-D LLVM vector.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto exp =
+              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+        },
+        rewriter);
+  }
+};
+
+// A `log1p` is converted into `log(1 + ...)`.
+struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
+  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `op` as an `fadd` of a constant 1.0 and the operand, followed
+  /// by an LLVM `log`.
+  ///
+  /// Operands whose converted type is not an LLVM array (scalars and 1-D
+  /// vectors) are rewritten in place. Otherwise the rewrite is delegated to
+  /// `handleMultidimensionalVectors`, which is handed a callback that builds
+  /// the same three ops for one 1-D vector type.
+  ///
+  /// NOTE(review): `log(1 + x)` is less accurate than a dedicated log1p for
+  /// `x` close to zero; confirm this is acceptable for users of the pattern.
+  LogicalResult
+  matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::Log1pOp::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      // The constant must have the operand's type: a splat for vectors, a
+      // plain float attribute for scalars.
+      LLVM::ConstantOp one =
+          LLVM::isCompatibleVectorType(operandType)
+              ? rewriter.create<LLVM::ConstantOp>(
+                    loc, operandType,
+                    SplatElementsAttr::get(resultType.cast<ShapedType>(),
+                                           floatOne))
+              : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+
+      auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
+                                               transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+      return success();
+    }
+
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Splat of 1.0 with as many elements as the 1-D LLVM vector.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
+                                                   operands[0]);
+          return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+        },
+        rewriter);
+  }
+};
+
+// A `rsqrt` is converted into `1 / sqrt`.
+struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
+  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `op` as an LLVM `sqrt` followed by an `fdiv` of a constant 1.0
+  /// by that square root.
+  ///
+  /// Operands whose converted type is not an LLVM array (scalars and 1-D
+  /// vectors) are rewritten in place. Otherwise the rewrite is delegated to
+  /// `handleMultidimensionalVectors`, which is handed a callback that builds
+  /// the same three ops for one 1-D vector type.
+  LogicalResult
+  matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::RsqrtOp::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    // Report why the pattern does not apply instead of failing silently.
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      // The constant must have the operand's type: a splat for vectors, a
+      // plain float attribute for scalars.
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+      return success();
+    }
+
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Splat of 1.0 with as many elements as the 1-D LLVM vector.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto sqrt =
+              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
+        },
+        rewriter);
+  }
+};
+
+/// Function pass that applies the Math-to-LLVM patterns with a partial
+/// conversion and signals pass failure if that conversion fails.
+struct ConvertMathToLLVMPass
+    : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
+  ConvertMathToLLVMPass() = default;
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    LLVMTypeConverter converter(&getContext());
+    populateMathToLLVMConversionPatterns(converter, patterns);
+    LLVMConversionTarget target(getContext());
+    // NOTE(review): the template argument was lost from this line; the cast
+    // op between builtin and LLVM types is assumed here. Confirm against the
+    // upstream revision before applying.
+    target.addLegalOp<LLVM::DialectCastOp>();
+    if (failed(
+            applyPartialConversion(getFunction(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers every Math-to-LLVM lowering pattern listed below in `patterns`,
+/// each constructed with `converter`.
+void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    CosOpLowering,
+    ExpOpLowering,
+    Exp2OpLowering,
+    ExpM1OpLowering,
+    Log10OpLowering,
+    Log1pOpLowering,
+    Log2OpLowering,
+    LogOpLowering,
+    PowFOpLowering,
+    RsqrtOpLowering,
+    SinOpLowering,
+    SqrtOpLowering
+  >(converter);
+  // clang-format on
+}
+
+/// Creates an instance of the Math-to-LLVM conversion pass.
+std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
+  return std::make_unique<ConvertMathToLLVMPass>();
+}
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -373,25 +373,17 @@ using CeilFOpLowering = VectorConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; -using CosOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; -using ExpOpLowering = VectorConvertToLLVMPattern; -using Exp2OpLowering = VectorConvertToLLVMPattern; using FPExtOpLowering = VectorConvertToLLVMPattern; using FPToSIOpLowering = VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; using FPTruncOpLowering = VectorConvertToLLVMPattern; using FloorFOpLowering = VectorConvertToLLVMPattern; using FmaFOpLowering = VectorConvertToLLVMPattern; -using Log10OpLowering = - VectorConvertToLLVMPattern; -using Log2OpLowering = VectorConvertToLLVMPattern; -using LogOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrOpLowering = VectorConvertToLLVMPattern; -using PowFOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = 
VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; @@ -405,8 +397,6 @@ VectorConvertToLLVMPattern; using SignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; -using SinOpLowering = VectorConvertToLLVMPattern; -using SqrtOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncateIOpLowering = VectorConvertToLLVMPattern; @@ -656,169 +646,6 @@ using Super::Super; }; -// A `expm1` is converted into `exp - 1`. -struct ExpM1OpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(math::ExpM1Op op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - math::ExpM1Op::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); - - if (!operandType || !LLVM::isCompatibleType(operandType)) - return failure(); - - auto loc = op.getLoc(); - auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); - auto floatOne = rewriter.getFloatAttr(floatType, 1.0); - - if (!operandType.isa()) { - LLVM::ConstantOp one; - if (LLVM::isCompatibleVectorType(operandType)) { - one = rewriter.create( - loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); - } else { - one = rewriter.create(loc, operandType, floatOne); - } - auto exp = rewriter.create(loc, transformed.operand()); - rewriter.replaceOpWithNewOp(op, operandType, exp, one); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(op, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - 
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), - floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto exp = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, exp, one); - }, - rewriter); - } -}; - -// A `log1p` is converted into `log(1 + ...)`. -struct Log1pOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(math::Log1pOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - math::Log1pOp::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); - - if (!operandType || !LLVM::isCompatibleType(operandType)) - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - - auto loc = op.getLoc(); - auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); - auto floatOne = rewriter.getFloatAttr(floatType, 1.0); - - if (!operandType.isa()) { - LLVM::ConstantOp one = - LLVM::isCompatibleVectorType(operandType) - ? 
rewriter.create( - loc, operandType, - SplatElementsAttr::get(resultType.cast(), - floatOne)) - : rewriter.create(loc, operandType, floatOne); - - auto add = rewriter.create(loc, operandType, one, - transformed.operand()); - rewriter.replaceOpWithNewOp(op, operandType, add); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(op, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), - floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create(loc, llvm1DVectorTy, one, - operands[0]); - return rewriter.create(loc, llvm1DVectorTy, add); - }, - rewriter); - } -}; - -// A `rsqrt` is converted into `1 / sqrt`. 
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(math::RsqrtOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - math::RsqrtOp::Adaptor transformed(operands); - auto operandType = transformed.operand().getType(); - - if (!operandType || !LLVM::isCompatibleType(operandType)) - return failure(); - - auto loc = op.getLoc(); - auto resultType = op.getResult().getType(); - auto floatType = getElementTypeOrSelf(resultType).cast(); - auto floatOne = rewriter.getFloatAttr(floatType, 1.0); - - if (!operandType.isa()) { - LLVM::ConstantOp one; - if (LLVM::isCompatibleVectorType(operandType)) { - one = rewriter.create( - loc, operandType, - SplatElementsAttr::get(resultType.cast(), floatOne)); - } else { - one = rewriter.create(loc, operandType, floatOne); - } - auto sqrt = rewriter.create(loc, transformed.operand()); - rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return failure(); - - return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), operands, *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), - floatOne); - auto one = - rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto sqrt = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, one, sqrt); - }, - rewriter); - } -}; - struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1375,20 +1202,12 @@ CmpIOpLowering, CondBranchOpLowering, CopySignOpLowering, - CosOpLowering, ConstantOpLowering, DialectCastOpLowering, DivFOpLowering, - ExpOpLowering, - Exp2OpLowering, - ExpM1OpLowering, FloorFOpLowering, 
FmaFOpLowering, GenericAtomicRMWOpLowering, - LogOpLowering, - Log10OpLowering, - Log1pOpLowering, - Log2OpLowering, FPExtOpLowering, FPToSIOpLowering, FPToUIOpLowering, @@ -1398,11 +1217,9 @@ MulIOpLowering, NegFOpLowering, OrOpLowering, - PowFOpLowering, RemFOpLowering, RankOpLowering, ReturnOpLowering, - RsqrtOpLowering, SIToFPOpLowering, SelectOpLowering, ShiftLeftOpLowering, @@ -1410,10 +1227,8 @@ SignedDivIOpLowering, SignedRemIOpLowering, SignedShiftRightOpLowering, - SinOpLowering, SplatOpLowering, SplatNdOpLowering, - SqrtOpLowering, SubFOpLowering, SubIOpLowering, SwitchOpLowering, diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s +// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm -convert-math-to-llvm -convert-std-to-llvm | FileCheck %s // CHECK-LABEL: llvm.func @complex_abs // CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -0,0 +1,121 @@ +// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm | FileCheck %s + +// CHECK-LABEL: @ops +func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) { +// CHECK: = "llvm.intr.exp"(%{{.*}}) : (f32) -> f32 + %13 = math.exp %arg0 : f32 +// CHECK: = "llvm.intr.exp2"(%{{.*}}) : (f32) -> f32 + %14 = math.exp2 %arg0 : f32 +// CHECK: = "llvm.intr.sqrt"(%{{.*}}) : (f32) -> f32 + %19 = math.sqrt %arg0 : f32 +// CHECK: = "llvm.intr.sqrt"(%{{.*}}) : (f64) -> f64 + %20 = math.sqrt %arg4 : f64 + std.return +} + +// ----- + +// CHECK-LABEL: func @log1p( 
+// CHECK-SAME: f32 +func @log1p(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32 + // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32 + %0 = math.log1p %arg0 : f32 + std.return +} + +// ----- + +// CHECK-LABEL: func @log1p_2dvector( +func @log1p_2dvector(%arg0 : vector<4x3xf32>) { + // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] : vector<3xf32> + // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (vector<3xf32>) -> vector<3xf32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + %0 = math.log1p %arg0 : vector<4x3xf32> + std.return +} + +// ----- + +// CHECK-LABEL: func @expm1( +// CHECK-SAME: f32 +func @expm1(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32 + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 + %0 = math.expm1 %arg0 : f32 + std.return +} + +// ----- + +// CHECK-LABEL: func @rsqrt( +// CHECK-SAME: f32 +func @rsqrt(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f32) -> f32 + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f32 + %0 = math.rsqrt %arg0 : f32 + std.return +} + +// ----- + +// CHECK-LABEL: func @sine( +// CHECK-SAME: f32 +func @sine(%arg0 : f32) { + // CHECK: "llvm.intr.sin"(%arg0) : (f32) -> f32 + %0 = math.sin %arg0 : f32 + std.return +} + +// ----- + +// CHECK-LABEL: func @rsqrt_double( +// CHECK-SAME: f64 +func @rsqrt_double(%arg0 : f64) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64 + // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f64) -> f64 
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f64 + %0 = math.rsqrt %arg0 : f64 + std.return +} + +// ----- + +// CHECK-LABEL: func @rsqrt_vector( +// CHECK-SAME: vector<4xf32> +func @rsqrt_vector(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<4xf32> + %0 = math.rsqrt %arg0 : vector<4xf32> + std.return +} + +// ----- + +// CHECK-LABEL: func @rsqrt_multidim_vector( +func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { + // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> + // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (vector<3xf32>) -> vector<3xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<3xf32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + %0 = math.rsqrt %arg0 : vector<4x3xf32> + std.return +} + +// ----- + +// CHECK-LABEL: func @powf( +// CHECK-SAME: f64 +func @powf(%arg0 : f64) { + // CHECK: %[[POWF:.*]] = "llvm.intr.pow"(%arg0, %arg0) : (f64, f64) -> f64 + %0 = math.powf %arg0, %arg0 : f64 + std.return +} + diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -463,48 +463,40 @@ // CHECK-LABEL: @ops func @ops(f32, f32, i32, i32, f64) -> (f32, i32) { ^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64): -// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : f32 +// CHECK: = llvm.fsub %arg0, %arg1 : f32 %0 = subf %arg0, %arg1: f32 -// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : i32 +// CHECK: = 
llvm.sub %arg2, %arg3 : i32 %1 = subi %arg2, %arg3: i32 -// CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : i32 +// CHECK: = llvm.icmp "slt" %arg2, %1 : i32 %2 = cmpi slt, %arg2, %1 : i32 -// CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : i32 +// CHECK: = llvm.sdiv %arg2, %arg3 : i32 %3 = divi_signed %arg2, %arg3 : i32 -// CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : i32 +// CHECK: = llvm.udiv %arg2, %arg3 : i32 %4 = divi_unsigned %arg2, %arg3 : i32 -// CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : i32 +// CHECK: = llvm.srem %arg2, %arg3 : i32 %5 = remi_signed %arg2, %arg3 : i32 -// CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : i32 +// CHECK: = llvm.urem %arg2, %arg3 : i32 %6 = remi_unsigned %arg2, %arg3 : i32 -// CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : i1, i32 +// CHECK: = llvm.select %2, %arg2, %arg3 : i1, i32 %7 = select %2, %arg2, %arg3 : i32 -// CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : f32 +// CHECK: = llvm.fdiv %arg0, %arg1 : f32 %8 = divf %arg0, %arg1 : f32 -// CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : f32 +// CHECK: = llvm.frem %arg0, %arg1 : f32 %9 = remf %arg0, %arg1 : f32 -// CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : i32 +// CHECK: = llvm.and %arg2, %arg3 : i32 %10 = and %arg2, %arg3 : i32 -// CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : i32 +// CHECK: = llvm.or %arg2, %arg3 : i32 %11 = or %arg2, %arg3 : i32 -// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : i32 +// CHECK: = llvm.xor %arg2, %arg3 : i32 %12 = xor %arg2, %arg3 : i32 -// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (f32) -> f32 - %13 = math.exp %arg0 : f32 -// CHECK-NEXT: %14 = "llvm.intr.exp2"(%arg0) : (f32) -> f32 - %14 = math.exp2 %arg0 : f32 -// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : f64 +// CHECK: = llvm.mlir.constant(7.900000e-01 : f64) : f64 %15 = constant 7.9e-01 : f64 -// CHECK-NEXT: %16 = llvm.shl %arg2, %arg3 : i32 +// CHECK: = llvm.shl %arg2, %arg3 : i32 %16 = shift_left %arg2, %arg3 : i32 -// CHECK-NEXT: %17 = llvm.ashr %arg2, %arg3 : i32 +// CHECK: = llvm.ashr %arg2, %arg3 : 
i32 %17 = shift_right_signed %arg2, %arg3 : i32 -// CHECK-NEXT: %18 = llvm.lshr %arg2, %arg3 : i32 +// CHECK: = llvm.lshr %arg2, %arg3 : i32 %18 = shift_right_unsigned %arg2, %arg3 : i32 -// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (f32) -> f32 - %19 = math.sqrt %arg0 : f32 -// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (f64) -> f64 - %20 = math.sqrt %arg4 : f64 return %0, %4 : f32, i32 } @@ -859,66 +851,6 @@ // CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK32: llvm.mlir.constant(1 : index) : i32 - -// ----- - -// CHECK-LABEL: func @log1p( -// CHECK-SAME: f32 -func @log1p(%arg0 : f32) { - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32 - // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32 - %0 = math.log1p %arg0 : f32 - std.return -} - -// ----- - -// CHECK-LABEL: func @log1p_2dvector( -func @log1p_2dvector(%arg0 : vector<4x3xf32>) { - // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>> - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> - // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] : vector<3xf32> - // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (vector<3xf32>) -> vector<3xf32> - // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %0[0] : !llvm.array<4 x vector<3xf32>> - %0 = math.log1p %arg0 : vector<4x3xf32> - std.return -} - -// ----- - -// CHECK-LABEL: func @expm1( -// CHECK-SAME: f32 -func @expm1(%arg0 : f32) { - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32 - // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 - %0 = math.expm1 %arg0 : f32 - std.return -} - -// ----- - -// CHECK-LABEL: func @rsqrt( -// CHECK-SAME: f32 -func @rsqrt(%arg0 : f32) { - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[SQRT:.*]] = 
"llvm.intr.sqrt"(%arg0) : (f32) -> f32 - // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f32 - %0 = math.rsqrt %arg0 : f32 - std.return -} - -// ----- - -// CHECK-LABEL: func @sine( -// CHECK-SAME: f32 -func @sine(%arg0 : f32) { - // CHECK: "llvm.intr.sin"(%arg0) : (f32) -> f32 - %0 = math.sin %arg0 : f32 - std.return -} - // ----- // CHECK-LABEL: func @ceilf( @@ -941,45 +873,6 @@ // ----- - -// CHECK-LABEL: func @rsqrt_double( -// CHECK-SAME: f64 -func @rsqrt_double(%arg0 : f64) { - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64 - // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f64) -> f64 - // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f64 - %0 = math.rsqrt %arg0 : f64 - std.return -} - -// ----- - -// CHECK-LABEL: func @rsqrt_vector( -// CHECK-SAME: vector<4xf32> -func @rsqrt_vector(%arg0 : vector<4xf32>) { - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> - // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (vector<4xf32>) -> vector<4xf32> - // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<4xf32> - %0 = math.rsqrt %arg0 : vector<4xf32> - std.return -} - -// ----- - -// CHECK-LABEL: func @rsqrt_multidim_vector( -// CHECK-SAME: !llvm.array<4 x vector<3xf32>> -func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { - // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>> - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> - // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (vector<3xf32>) -> vector<3xf32> - // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<3xf32> - // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %0[0] : !llvm.array<4 x vector<3xf32>> - %0 = math.rsqrt %arg0 : vector<4x3xf32> - std.return -} - -// ----- - // Lowers `assert` to a function call to `abort` if the assertion is violated. 
// CHECK: llvm.func @abort() // CHECK-LABEL: @assert_test_function @@ -1010,16 +903,6 @@ // ----- -// CHECK-LABEL: func @powf( -// CHECK-SAME: f64 -func @powf(%arg0 : f64) { - // CHECK: %[[POWF:.*]] = "llvm.intr.pow"(%arg0, %arg0) : (f64, f64) -> f64 - %0 = math.powf %arg0, %arg0 : f64 - std.return -} - -// ----- - // CHECK-LABEL: func @fmaf( // CHECK-SAME: %[[ARG0:.*]]: f32 // CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>