Index: mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -0,0 +1,23 @@
+//===- MathToFuncs.h - Math to outlined impl conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+#define MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+// Creates a pass that converts some Math operations into calls of
+// functions containing a software implementation of these operations.
+std::unique_ptr<Pass> createConvertMathToFuncsPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -511,6 +511,27 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToFuncs
+//===----------------------------------------------------------------------===// + +def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { + let summary = "Convert Math operations to calls of outlined implementations"; + let description = [{ + This pass converts supported Math ops to calls of compiler generated + functions implementing these operations in software. + LLVMDialect is used for LinkonceODR linkage of the generated functions. + }]; + let constructor = "mlir::createConvertMathToFuncsPass()"; + let dependentDialects = [ + "arith::ArithmeticDialect", + "cf::ControlFlowDialect", + "func::FuncDialect", + "vector::VectorDialect", + "LLVM::LLVMDialect", + ]; +} + //===----------------------------------------------------------------------===// // MemRefToLLVM //===----------------------------------------------------------------------===// Index: mlir/lib/Conversion/CMakeLists.txt =================================================================== --- mlir/lib/Conversion/CMakeLists.txt +++ mlir/lib/Conversion/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(LinalgToSPIRV) add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) +add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToSPIRV) Index: mlir/lib/Conversion/MathToFuncs/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Conversion/MathToFuncs/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_conversion_library(MLIRMathToFuncs + MathToFuncs.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToFuncs + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithmeticDialect + MLIRControlFlowDialect + MLIRFuncDialect + MLIRLLVMDialect + MLIRMathDialect + MLIRPass + MLIRTransforms + MLIRVectorDialect + MLIRVectorUtils + ) Index: mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp 
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -0,0 +1,431 @@
+//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace {
+// Utility for FunctionType's: string encoding and a std::less-like ordering.
+struct FunctionTypesUtils {
+  /// Encode \p type as '_'-prefixed result types followed by input types.
+  static std::string stringizeType(const FunctionType &type);
+
+  /// std::less-like ordering of FunctionType's by their stringized form.
+  bool operator()(const FunctionType &lhs, const FunctionType &rhs) const;
+};
+
+// Pattern to convert vector operations to scalar operations.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+// Callback type for getting the pre-generated FuncOp implementing
+// a power operation of the given function type.
+using GetPowerFuncCallbackTy = std::function<func::FuncOp(FunctionType)>;
+
+// Pattern to convert scalar IPowIOp into a call of outlined
+// software implementation.
+struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
+
+private:
+  GetPowerFuncCallbackTy getFuncOpCallback;
+
+public:
+  IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+      : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
+
+  /// Convert IPowI into a call to a local function implementing
+  /// the power operation. The local function computes a scalar result,
+  /// so vector forms of IPowI are linearized.
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+
+  static func::FuncOp createElementFunc(ModuleOp *module, FunctionType funcType);
+};
+} // namespace
+
+std::string FunctionTypesUtils::stringizeType(const FunctionType &type) {
+  std::string result;
+  llvm::raw_string_ostream typeOS(result);
+  for (unsigned i = 0, e = type.getNumResults(); i != e; ++i)
+    typeOS << '_' << type.getResult(i);
+  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+    typeOS << '_' << type.getInput(i);
+
+  assert(!result.empty() && "invalid FunctionType");
+  return result;
+}
+
+bool FunctionTypesUtils::operator()(const FunctionType &lhs, const FunctionType &rhs) const {
+  return stringizeType(lhs) < stringizeType(rhs);
+}
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  Type opType = op.getType();
+  Location loc = op.getLoc();
+  auto vecType = opType.template dyn_cast<VectorType>();
+
+  if (!vecType)
+    return rewriter.notifyMatchFailure(op, "not a vector operation");
+  if (!vecType.hasRank())
+    return rewriter.notifyMatchFailure(op, "unknown vector rank");
+  ArrayRef<int64_t> shape = vecType.getShape();
+  int64_t numElements = vecType.getNumElements();
+
+  Value result = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(
+               vecType, IntegerAttr::get(vecType.getElementType(), 0)));
+  SmallVector<int64_t> ones(shape.size(), 1);
+  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+    SmallVector<Value> operands;
+    for (Value input : op->getOperands())
+      operands.push_back(
+          rewriter.create<vector::ExtractOp>(loc, input, positions));
+    Value scalarOp =
+        rewriter.create<Op>(loc, vecType.getElementType(), operands);
+    result =
+        rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+/// Create linkonce_odr function to implement the power function with
+/// the given \p funcType type inside \p module. \p funcType must be
+/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
+///
+///   template <typename T>
+///   T __mlir_math_ipowi_*(T b, T p) {
+///     if (p == T(0))
+///       return T(1);
+///     if (p < T(0)) {
+///       if (b == T(0))
+///         return T(1) / T(0); // trigger div-by-zero
+///       if (b == T(1))
+///         return T(1);
+///       if (b == T(-1)) {
+///         if (p & T(1))
+///           return T(-1);
+///         return T(1);
+///       }
+///       return T(0);
+///     }
+///     T result = T(1);
+///     while (true) {
+///       if (p & T(1))
+///         result *= b;
+///       p >>= T(1);
+///       if (p == T(0))
+///         return result;
+///       b *= b;
+///     }
+///   }
+func::FuncOp IPowIOpLowering::createElementFunc(ModuleOp *module,
+                                                FunctionType funcType) {
+  assert(funcType.getNumResults() == 1 && funcType.getNumInputs() == 2 &&
+         funcType.getResult(0).isa<IntegerType>() &&
+         funcType.getInput(0).isa<IntegerType>() &&
+         funcType.getInput(1).isa<IntegerType>() &&
+         "invalid function type deduced from IPowIOp");
+
+  IntegerType elementType = funcType.getInput(0).cast<IntegerType>();
+  ImplicitLocOpBuilder builder =
+      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
+
+  std::string funcName("__mlir_math_ipowi");
+  funcName += FunctionTypesUtils::stringizeType(funcType);
+
+  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  Block *entryBlock = funcOp.addEntryBlock();
+  Region *funcBody = entryBlock->getParent();
+
+  Value bArg = funcOp.getArgument(0);
+  Value pArg = funcOp.getArgument(1);
+  builder.setInsertionPointToEnd(entryBlock);
+  Value zeroValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, 0));
+  Value oneValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, 1));
+  Value minusOneValue = builder.create<arith::ConstantOp>(
+      elementType,
+      builder.getIntegerAttr(elementType,
+                             APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
+                                   /*isSigned=*/true)));
+
+  // if (p == T(0))
+  //   return T(1);
+  auto pIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
+  Block *thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneValue);
+  Block *fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  builder.setInsertionPointToEnd(pIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+
+  // if (p < T(0)) {
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto pIsNeg =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
+  // if (b == T(0))
+  builder.createBlock(funcBody);
+  auto bIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
+  // return T(1) / T(0);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(
+      builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(0)).
+  builder.setInsertionPointToEnd(bIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
+
+  // if (b == T(1))
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsOne =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
+  // return T(1);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(1)).
+  builder.setInsertionPointToEnd(bIsOne->getBlock());
+  builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
+
+  // if (b == T(-1)) {
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                   bArg, minusOneValue);
+  // if (p & T(1))
+  builder.createBlock(funcBody);
+  auto pIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
+      zeroValue);
+  // return T(-1);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(minusOneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p & T(1)).
+  builder.setInsertionPointToEnd(pIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
+
+  // return T(1);
+  // } // b == T(-1)
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<func::ReturnOp>(oneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(-1)).
+  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
+  builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
+                                   fallthroughBlock);
+
+  // return T(0);
+  // } // (p < T(0))
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<func::ReturnOp>(zeroValue);
+  Block *loopHeader = builder.createBlock(
+      funcBody, funcBody->end(), {elementType, elementType, elementType},
+      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
+  // Set up conditional branch for (p < T(0)).
+  builder.setInsertionPointToEnd(pIsNeg->getBlock());
+  // Set initial values of 'result', 'b' and 'p' for the loop.
+  builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
+                                   ValueRange{oneValue, bArg, pArg});
+
+  // T result = T(1);
+  // while (true) {
+  //   if (p & T(1))
+  //     result *= b;
+  //   p >>= T(1);
+  //   if (p == T(0))
+  //     return result;
+  //   b *= b;
+  // }
+  Value resultTmp = loopHeader->getArgument(0);
+  Value baseTmp = loopHeader->getArgument(1);
+  Value powerTmp = loopHeader->getArgument(2);
+  builder.setInsertionPointToEnd(loopHeader);
+
+  // if (p & T(1))
+  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne,
+      builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
+  thenBlock = builder.createBlock(funcBody);
+  // result *= b;
+  Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(thenBlock);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+  // Set up conditional branch for (p & T(1)).
+  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
+                                   resultTmp);
+  // Merged 'result'.
+  newResultTmp = fallthroughBlock->getArgument(0);
+
+  // p >>= T(1);
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
+
+  // if (p == T(0))
+  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                      newPowerTmp, zeroValue);
+  // return result;
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(newResultTmp);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
+
+  // b *= b;
+  // }
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
+  // Pass new values for 'result', 'b' and 'p' to the loop header.
+  builder.create<cf::BranchOp>(
+      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+  return funcOp;
+}
+
+/// Convert IPowI into a call to a local function implementing
+/// the power operation. The local function computes a scalar result,
+/// so vector forms of IPowI are linearized.
+LogicalResult
+IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
+                                 PatternRewriter &rewriter) const {
+  auto resultType = op.getResult().getType().dyn_cast<IntegerType>();
+  auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+  auto exponentType = op.getOperands()[1].getType().dyn_cast<IntegerType>();
+
+  if (!baseType)
+    return rewriter.notifyMatchFailure(op, "non-integer base operand");
+  if (baseType != exponentType)
+    return rewriter.notifyMatchFailure(op, "operands' types mismatch");
+  if (baseType != resultType)
+    return rewriter.notifyMatchFailure(op, "result/operands types mismatch");
+
+  FunctionType funcType = FunctionType::get(
+      rewriter.getContext(), {baseType, exponentType}, resultType);
+
+  // The outlined software implementation must have been already
+  // generated.
+  func::FuncOp elementFunc = getFuncOpCallback(funcType);
+  if (!elementFunc)
+    return rewriter.notifyMatchFailure(op, "missing software implementation");
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
+  return success();
+}
+
+namespace {
+using PowerFuncsMap =
+    std::map<FunctionType, func::FuncOp, FunctionTypesUtils>;
+
+struct ConvertMathToFuncsPass
+    : public ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
+  ConvertMathToFuncsPass() = default;
+
+  void runOnOperation() override;
+
+private:
+  // Generate outlined implementations for power operations
+  // and store them in powerFuncs map.
+  void preprocessPowOperations();
+
+  // A map between function types deduced from power operations
+  // and the corresponding outlined software implementations
+  // of these operations.
+  PowerFuncsMap powerFuncs;
+};
+} // namespace
+
+void ConvertMathToFuncsPass::preprocessPowOperations() {
+  ModuleOp module = getOperation();
+
+  module.walk([&](Operation *op) {
+    TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) {
+      Type resultType = getElementTypeOrSelf(op.getResult().getType());
+      Type baseType = getElementTypeOrSelf(op.getOperands()[0].getType());
+      Type exponentType = getElementTypeOrSelf(op.getOperands()[1].getType());
+      FunctionType funcType = FunctionType::get(
+          &getContext(), {baseType, exponentType}, resultType);
+
+      // Generate the software implementation of this operation,
+      // if it has not been generated yet.
+      auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{});
+      if (entry.second)
+        entry.first->second =
+            IPowIOpLowering::createElementFunc(&module, funcType);
+    });
+  });
+}
+
+void ConvertMathToFuncsPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Create outlined implementations for power operations.
+  preprocessPowOperations();
+
+  RewritePatternSet patterns(&getContext());
+  patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
+
+  // Returns the FuncOp stored in powerFuncs for \p type, or null if absent.
+  auto getPowerFuncOpByType = [&](FunctionType type) -> func::FuncOp {
+    auto it = powerFuncs.find(type);
+    if (it == powerFuncs.end())
+      return {};
+
+    return it->second;
+  };
+  patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
+                         cf::ControlFlowDialect, func::FuncDialect,
+                         vector::VectorDialect>();
+  target.addIllegalOp<math::IPowIOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
+  return std::make_unique<ConvertMathToFuncsPass>();
+}
Index: mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="convert-math-to-funcs" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i64,
+// CHECK-SAME: %[[ARG1:.+]]: i64)
+func.func @ipowi(%arg0: i64, %arg1: i64) {
+  // CHECK: call @__mlir_math_ipowi_i64_i64_i64(%[[ARG0]], %[[ARG1]]) : (i64, i64) -> i64
+  %0 = math.ipowi %arg0, %arg1 : i64
+  func.return
+}
+
+// CHECK-LABEL: func.func private @__mlir_math_ipowi_i64_i64_i64(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64
+// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i64
+// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb2:
+// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i64, i64, i64)
+// CHECK: ^bb3:
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i64 +// CHECK: return %[[VAL_8]] : i64 +// CHECK: ^bb5: +// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i64 +// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7 +// CHECK: ^bb6: +// CHECK: return %[[VAL_3]] : i64 +// CHECK: ^bb7: +// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i64 +// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11 +// CHECK: ^bb8: +// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i64 +// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10 +// CHECK: ^bb9: +// CHECK: return %[[VAL_4]] : i64 +// CHECK: ^bb10: +// CHECK: return %[[VAL_3]] : i64 +// CHECK: ^bb11: +// CHECK: return %[[VAL_2]] : i64 +// CHECK: ^bb12(%[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i64, %[[VAL_15:.*]]: i64): +// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i64 +// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i64) +// CHECK: ^bb13: +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i64 +// CHECK: cf.br ^bb14(%[[VAL_18]] : i64) +// CHECK: ^bb14(%[[VAL_19:.*]]: i64): +// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i64 +// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16 +// CHECK: ^bb15: +// CHECK: return %[[VAL_19]] : i64 +// CHECK: ^bb16: +// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i64 +// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i64, i64, i64) +// CHECK: } + +// ----- + +// CHECK-LABEL: func @ipowi( +// CHECK-SAME: %[[ARG0:.+]]: i8, +// CHECK-SAME: %[[ARG1:.+]]: i8) + // CHECK: call @__mlir_math_ipowi_i8_i8_i8(%[[ARG0]], %[[ARG1]]) : (i8, i8) -> i8 +func.func @ipowi(%arg0: i8, %arg1: i8) { + %0 = math.ipowi %arg0, %arg1 : i8 + func.return +} + +// CHECK-LABEL: func.func private 
@__mlir_math_ipowi_i8_i8_i8( +// CHECK-SAME: %[[VAL_0:.*]]: i8, +// CHECK-SAME: %[[VAL_1:.*]]: i8) -> i8 +// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8 +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i8 +// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i8 +// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: return %[[VAL_3]] : i8 +// CHECK: ^bb2: +// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i8, i8, i8) +// CHECK: ^bb3: +// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5 +// CHECK: ^bb4: +// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i8 +// CHECK: return %[[VAL_8]] : i8 +// CHECK: ^bb5: +// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i8 +// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7 +// CHECK: ^bb6: +// CHECK: return %[[VAL_3]] : i8 +// CHECK: ^bb7: +// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i8 +// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11 +// CHECK: ^bb8: +// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10 +// CHECK: ^bb9: +// CHECK: return %[[VAL_4]] : i8 +// CHECK: ^bb10: +// CHECK: return %[[VAL_3]] : i8 +// CHECK: ^bb11: +// CHECK: return %[[VAL_2]] : i8 +// CHECK: ^bb12(%[[VAL_13:.*]]: i8, %[[VAL_14:.*]]: i8, %[[VAL_15:.*]]: i8): +// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i8) +// CHECK: ^bb13: +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i8 +// CHECK: cf.br ^bb14(%[[VAL_18]] : i8) 
+// CHECK: ^bb14(%[[VAL_19:.*]]: i8): +// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i8 +// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16 +// CHECK: ^bb15: +// CHECK: return %[[VAL_19]] : i8 +// CHECK: ^bb16: +// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i8 +// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i8, i8, i8) +// CHECK: } + +// ----- + +// CHECK-LABEL: func.func @ipowi_vec( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xi64>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2x3xi64>) { +func.func @ipowi_vec(%arg0: vector<2x3xi64>, %arg1: vector<2x3xi64>) { +// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x3xi64> +// CHECK: %[[B00:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xi64> +// CHECK: %[[E00:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64> +// CHECK: %[[R00:.*]] = call @__mlir_math_ipowi_i64_i64_i64(%[[B00]], %[[E00]]) : (i64, i64) -> i64 +// CHECK: %[[TMP00:.*]] = vector.insert %[[R00]], %[[CST]] [0, 0] : i64 into vector<2x3xi64> +// CHECK: %[[B01:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xi64> +// CHECK: %[[E01:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64> +// CHECK: %[[R01:.*]] = call @__mlir_math_ipowi_i64_i64_i64(%[[B01]], %[[E01]]) : (i64, i64) -> i64 +// CHECK: %[[TMP01:.*]] = vector.insert %[[R01]], %[[TMP00]] [0, 1] : i64 into vector<2x3xi64> +// CHECK: %[[B02:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xi64> +// CHECK: %[[E02:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64> +// CHECK: %[[R02:.*]] = call @__mlir_math_ipowi_i64_i64_i64(%[[B02]], %[[E02]]) : (i64, i64) -> i64 +// CHECK: %[[TMP02:.*]] = vector.insert %[[R02]], %[[TMP01]] [0, 2] : i64 into vector<2x3xi64> +// CHECK: %[[B10:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xi64> +// CHECK: %[[E10:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64> +// CHECK: %[[R10:.*]] = call 
@__mlir_math_ipowi_i64_i64_i64(%[[B10]], %[[E10]]) : (i64, i64) -> i64 +// CHECK: %[[TMP10:.*]] = vector.insert %[[R10]], %[[TMP02]] [1, 0] : i64 into vector<2x3xi64> +// CHECK: %[[B11:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xi64> +// CHECK: %[[E11:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64> +// CHECK: %[[R11:.*]] = call @__mlir_math_ipowi_i64_i64_i64(%[[B11]], %[[E11]]) : (i64, i64) -> i64 +// CHECK: %[[TMP11:.*]] = vector.insert %[[R11]], %[[TMP10]] [1, 1] : i64 into vector<2x3xi64> +// CHECK: %[[B12:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xi64> +// CHECK: %[[E12:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64> +// CHECK: %[[R12:.*]] = call @__mlir_math_ipowi_i64_i64_i64(%[[B12]], %[[E12]]) : (i64, i64) -> i64 +// CHECK: %[[TMP12:.*]] = vector.insert %[[R12]], %[[TMP11]] [1, 2] : i64 into vector<2x3xi64> +// CHECK: return +// CHECK: } + %0 = math.ipowi %arg0, %arg1 : vector<2x3xi64> + func.return +}