Index: mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -0,0 +1,33 @@
+//===- MathToFuncs.h - Math to outlined impl conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+#define MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class PatternBenefit;
+class RewritePatternSet;
+
+/// Populate \p patterns with the rewrite patterns converting supported Math
+/// operations into calls of outlined functions, using \p benefit for each.
+void populateMathToFuncsConversionPatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit);
+
+/// Create the pass converting supported Math operations in a module into
+/// calls of outlined functions (see "convert-math-to-funcs" in Passes.td).
+std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToFuncsPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -511,6 +511,27 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToFuncs
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
+  let summary = "Convert Math operations to calls of outlined implementations";
+  let description = [{
+    This pass converts supported Math ops to calls of compiler generated
+    functions implementing these operations in software.
+    LLVMDialect is used for LinkonceODR linkage of the generated functions.
+  }];
+  let constructor = "mlir::createConvertMathToFuncsPass()";
+  let dependentDialects = [ // dialects of the ops/attributes the pass creates
+    "arith::ArithmeticDialect",
+    "cf::ControlFlowDialect",
+    "func::FuncDialect",
+    "vector::VectorDialect",
+    "LLVM::LLVMDialect", // provides the llvm.linkage attribute
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToLLVM
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -21,6 +21,7 @@
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
+add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
Index: mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToFuncs
+  MathToFuncs.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToFuncs
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
+  MLIRControlFlowDialect
+  MLIRFuncDialect
+  MLIRLLVMDialect # MathToFuncs.cpp uses LLVM::LinkageAttr
+  MLIRMathDialect
+  MLIRPass
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
Index: mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -0,0 +1,368 @@
+//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert vector operations to scalar operations.
+// The vector operation is unrolled into one scalar operation of the same
+// kind per element, so that the scalar lowering below can handle it.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+// Pattern to convert scalar IPowIOp into a call of outlined
+// software implementation.
+struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
+private:
+  func::FuncOp getElementFunc(math::IPowIOp op, IntegerType elementType,
+                              PatternRewriter &rewriter) const;
+
+public:
+  IPowIOpLowering(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<math::IPowIOp>(context, benefit) {}
+
+  /// Convert IPowI into a call to a local function implementing
+  /// the power operation. The local function computes a scalar result,
+  /// so vector forms of IPowI are linearized.
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+};
+
+} // namespace
+
+/// Unroll the vector operation \p op: for every element position (visited
+/// in linearized order, last dimension fastest) extract the operand
+/// elements, create the scalar operation and insert its result into the
+/// result vector. Fails to match if the result is not a vector.
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  Type opType = op.getType();
+  Location loc = op.getLoc();
+  auto vecType = opType.template dyn_cast<VectorType>();
+
+  if (!vecType)
+    return rewriter.notifyMatchFailure(op, "not a vector operation");
+  if (!vecType.hasRank())
+    return rewriter.notifyMatchFailure(op, "unknown vector rank");
+  ArrayRef<int64_t> shape = vecType.getShape();
+  int64_t numElements = vecType.getNumElements();
+
+  // Start from an all-zero vector; every element is overwritten below.
+  Value result = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(
+               vecType, IntegerAttr::get(vecType.getElementType(), 0)));
+  SmallVector<int64_t> ones(shape.size(), 1);
+  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+    SmallVector<Value> operands;
+    for (auto input : op->getOperands())
+      operands.push_back(
+          rewriter.create<vector::ExtractOp>(loc, input, positions));
+    Value scalarOp =
+        rewriter.create<Op>(loc, vecType.getElementType(), operands);
+    result =
+        rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+/// Create linkonce_odr function to implement the power function with
+/// the given \p elementType scalar type.
+///
+/// template <typename T>
+/// T __mlir_math_ipowi_*(T b, T p) {
+///   if (p == T(0))
+///     return T(1);
+///   if (p < T(0)) {
+///     if (b == T(0))
+///       return T(1) / T(0); // trigger div-by-zero
+///     if (b == T(1))
+///       return T(1);
+///     if (b == T(-1)) {
+///       if (p & T(1))
+///         return T(-1);
+///       return T(1);
+///     }
+///     return T(0);
+///   }
+///   T result = T(1);
+///   while (true) {
+///     if (p & T(1))
+///       result *= b;
+///     p >>= T(1);
+///     if (p == T(0))
+///       return result;
+///     b *= b;
+///   }
+/// }
+func::FuncOp IPowIOpLowering::getElementFunc(math::IPowIOp op,
+                                             IntegerType elementType,
+                                             PatternRewriter &rewriter) const {
+  // The function name is suffixed with the printed element type,
+  // e.g. __mlir_math_ipowi_i64.
+  std::string funcName("__mlir_math_ipowi_");
+  llvm::raw_string_ostream nameOS(funcName);
+  elementType.print(nameOS);
+
+  Operation *module = SymbolTable::getNearestSymbolTable(op);
+  FunctionType funcType = FunctionType::get(
+      rewriter.getContext(), {elementType, elementType}, {elementType});
+  // Reuse the function if the nearest symbol table already defines it.
+  if (auto funcOp = dyn_cast_or_null<func::FuncOp>(
+          SymbolTable::lookupSymbolIn(module, funcName))) {
+    assert(funcOp.getFunctionTypeAttr().getValue() == funcType &&
+           "ipowi function type mismatch");
+    return funcOp;
+  }
+
+  // Otherwise, build it at the end of the symbol table's first block.
+  // The guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Location loc = rewriter.getUnknownLoc();
+  rewriter.setInsertionPointToEnd(&module->getRegion(0).front());
+  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName, funcType);
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(rewriter.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  Block *entryBlock = funcOp.addEntryBlock();
+  Region *funcBody = entryBlock->getParent();
+
+  Value bArg = funcOp.getArgument(0);
+  Value pArg = funcOp.getArgument(1);
+  rewriter.setInsertionPointToEnd(entryBlock);
+  Value zeroValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getIntegerAttr(elementType, 0));
+  Value oneValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getIntegerAttr(elementType, 1));
+  Value minusOneValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType,
+      rewriter.getIntegerAttr(elementType,
+                              APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
+                                    /*isSigned=*/true)));
+
+  // if (p == T(0))
+  //   return T(1);
+  auto pIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                pArg, zeroValue);
+  Block *thenBlock = rewriter.createBlock(funcBody);
+  rewriter.create<func::ReturnOp>(loc, oneValue);
+  Block *fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  rewriter.setInsertionPointToEnd(pIsZero->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, pIsZero, thenBlock, fallthroughBlock);
+
+  // if (p < T(0)) {
+  // The 'sle' predicate is equivalent to 'slt' here, because the
+  // (p == T(0)) case has already been branched away above.
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  auto pIsNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
+                                               pArg, zeroValue);
+  // if (b == T(0))
+  rewriter.createBlock(funcBody);
+  auto bIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                bArg, zeroValue);
+  // return T(1) / T(0);
+  thenBlock = rewriter.createBlock(funcBody);
+  rewriter.create<func::ReturnOp>(
+      loc,
+      rewriter.create<arith::DivSIOp>(loc, oneValue, zeroValue).getResult());
+  fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (b == T(0)).
+  rewriter.setInsertionPointToEnd(bIsZero->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, bIsZero, thenBlock, fallthroughBlock);
+
+  // if (b == T(1))
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsOne = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                               bArg, oneValue);
+  // return T(1);
+  thenBlock = rewriter.createBlock(funcBody);
+  rewriter.create<func::ReturnOp>(loc, oneValue);
+  fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (b == T(1)).
+  rewriter.setInsertionPointToEnd(bIsOne->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, bIsOne, thenBlock, fallthroughBlock);
+
+  // if (b == T(-1)) {
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsMinusOne = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, bArg, minusOneValue);
+  // if (p & T(1))
+  rewriter.createBlock(funcBody);
+  auto pIsOdd = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ne,
+      rewriter.create<arith::AndIOp>(loc, pArg, oneValue), zeroValue);
+  // return T(-1);
+  thenBlock = rewriter.createBlock(funcBody);
+  rewriter.create<func::ReturnOp>(loc, minusOneValue);
+  fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (p & T(1)).
+  rewriter.setInsertionPointToEnd(pIsOdd->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, pIsOdd, thenBlock, fallthroughBlock);
+
+  // return T(1);
+  // } // b == T(-1)
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  rewriter.create<func::ReturnOp>(loc, oneValue);
+  fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (b == T(-1)).
+  rewriter.setInsertionPointToEnd(bIsMinusOne->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, bIsMinusOne, pIsOdd->getBlock(),
+                                    fallthroughBlock);
+
+  // return T(0);
+  // } // (p < T(0))
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  rewriter.create<func::ReturnOp>(loc, zeroValue);
+  Block *loopHeader = rewriter.createBlock(
+      funcBody, funcBody->end(), {elementType, elementType, elementType},
+      {loc, loc, loc});
+  // Set up conditional branch for (p < T(0)).
+  rewriter.setInsertionPointToEnd(pIsNeg->getBlock());
+  // Set initial values of 'result', 'b' and 'p' for the loop.
+  rewriter.create<cf::CondBranchOp>(loc, pIsNeg, bIsZero->getBlock(),
+                                    loopHeader,
+                                    ValueRange{oneValue, bArg, pArg});
+
+  // T result = T(1);
+  // while (true) {
+  //   if (p & T(1))
+  //     result *= b;
+  //   p >>= T(1);
+  //   if (p == T(0))
+  //     return result;
+  //   b *= b;
+  // }
+  Value resultTmp = loopHeader->getArgument(0);
+  Value baseTmp = loopHeader->getArgument(1);
+  Value powerTmp = loopHeader->getArgument(2);
+  rewriter.setInsertionPointToEnd(loopHeader);
+
+  // if (p & T(1))
+  auto powerTmpIsOdd = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ne,
+      rewriter.create<arith::AndIOp>(loc, powerTmp, oneValue), zeroValue);
+  thenBlock = rewriter.createBlock(funcBody);
+  // result *= b;
+  Value newResultTmp = rewriter.create<arith::MulIOp>(loc, resultTmp, baseTmp);
+  fallthroughBlock =
+      rewriter.createBlock(funcBody, funcBody->end(), elementType, loc);
+  rewriter.setInsertionPointToEnd(thenBlock);
+  rewriter.create<cf::BranchOp>(loc, newResultTmp, fallthroughBlock);
+  // Set up conditional branch for (p & T(1)).
+  rewriter.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, powerTmpIsOdd, thenBlock,
+                                    fallthroughBlock, resultTmp);
+  // Merged 'result'.
+  newResultTmp = fallthroughBlock->getArgument(0);
+
+  // p >>= T(1);
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  Value newPowerTmp = rewriter.create<arith::ShRUIOp>(loc, powerTmp, oneValue);
+
+  // if (p == T(0))
+  auto newPowerIsZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, newPowerTmp, zeroValue);
+  // return result;
+  thenBlock = rewriter.createBlock(funcBody);
+  rewriter.create<func::ReturnOp>(loc, newResultTmp);
+  fallthroughBlock = rewriter.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  rewriter.setInsertionPointToEnd(newPowerIsZero->getBlock());
+  rewriter.create<cf::CondBranchOp>(loc, newPowerIsZero, thenBlock,
+                                    fallthroughBlock);
+
+  // b *= b;
+  // }
+  rewriter.setInsertionPointToEnd(fallthroughBlock);
+  Value newBaseTmp = rewriter.create<arith::MulIOp>(loc, baseTmp, baseTmp);
+  // Pass new values for 'result', 'b' and 'p' to the loop header.
+  rewriter.create<cf::BranchOp>(
+      loc, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+  return funcOp;
+}
+
+/// Convert IPowI into a call to a local function implementing
+/// the power operation. The local function computes a scalar result,
+/// so vector forms of IPowI are linearized.
+LogicalResult
+IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
+                                 PatternRewriter &rewriter) const {
+  auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+  auto exponentType = op.getOperands()[1].getType().dyn_cast<IntegerType>();
+
+  if (!baseType)
+    return rewriter.notifyMatchFailure(op, "non-integer base operand");
+  if (baseType != exponentType)
+    return rewriter.notifyMatchFailure(op, "operands' types mismatch");
+
+  func::FuncOp elementFunc = getElementFunc(op, baseType, rewriter);
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
+  return success();
+}
+
+/// Add the patterns converting math::IPowIOp (unrolling of the vector form
+/// and outlining of the scalar form) to \p patterns with \p benefit.
+void mlir::populateMathToFuncsConversionPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit) {
+  patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext(), benefit);
+  patterns.add<IPowIOpLowering>(patterns.getContext(), benefit);
+}
+
+namespace {
+struct ConvertMathToFuncsPass
+    : public ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
+  ConvertMathToFuncsPass() = default;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToFuncsPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToFuncsConversionPatterns(patterns, /*benefit=*/1);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<arith::ArithmeticDialect, cf::ControlFlowDialect,
+                         func::FuncDialect, vector::VectorDialect,
+                         LLVM::LLVMDialect>();
+  // Only math::IPowIOp has to be converted; the partial conversion fails
+  // (and so does the pass) if any such operation cannot be legalized.
+  target.addIllegalOp<math::IPowIOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToFuncsPass() {
+  return std::make_unique<ConvertMathToFuncsPass>();
+}
Index: mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="convert-math-to-funcs" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i64,
+// CHECK-SAME: %[[ARG1:.+]]: i64)
+func.func @ipowi(%arg0: i64, %arg1: i64) {
+  // CHECK: call @__mlir_math_ipowi_i64(%[[ARG0]], %[[ARG1]]) : (i64, i64) -> i64
+  %0 = math.ipowi %arg0, %arg1 : i64
+  func.return
+}
+
+// CHECK-LABEL: func.func private @__mlir_math_ipowi_i64(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64
+// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i64
+// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb2:
+// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i64, i64, i64)
+// CHECK: ^bb3:
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i64
+// CHECK: return %[[VAL_8]] : i64
+// CHECK: ^bb5:
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK: ^bb6:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb7:
+// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i64
+// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK: ^bb8:
+// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK: ^bb9:
+// CHECK: return %[[VAL_4]] : i64
+// CHECK: ^bb10:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb11:
+// CHECK: return %[[VAL_2]] : i64
+// CHECK: ^bb12(%[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i64, %[[VAL_15:.*]]: i64):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i64)
+// CHECK: ^bb13:
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i64
+// CHECK: cf.br ^bb14(%[[VAL_18]] : i64)
+// CHECK: ^bb14(%[[VAL_19:.*]]: i64):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK: ^bb15:
+// CHECK: return %[[VAL_19]] : i64
+// CHECK: ^bb16:
+// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i64
+// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i64, i64, i64)
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i8,
+// CHECK-SAME: %[[ARG1:.+]]: i8)
+  // CHECK: call @__mlir_math_ipowi_i8(%[[ARG0]], %[[ARG1]]) : (i8, i8) -> i8
+func.func @ipowi(%arg0: i8, %arg1: i8) {
+  %0 = math.ipowi %arg0, %arg1 : i8
+  func.return
+}
+
+// CHECK-LABEL: func.func private @__mlir_math_ipowi_i8(
+// CHECK-SAME: %[[VAL_0:.*]]: i8,
+// CHECK-SAME: %[[VAL_1:.*]]: i8) -> i8
+// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i8
+// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i8
+// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb2:
+// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i8, i8, i8)
+// CHECK: ^bb3:
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: ^bb5:
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i8
+// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK: ^bb6:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb7:
+// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i8
+// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK: ^bb8:
+// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK: ^bb9:
+// CHECK: return %[[VAL_4]] : i8
+// CHECK: ^bb10:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb11:
+// CHECK: return %[[VAL_2]] : i8
+// CHECK: ^bb12(%[[VAL_13:.*]]: i8, %[[VAL_14:.*]]: i8, %[[VAL_15:.*]]: i8):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i8)
+// CHECK: ^bb13:
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i8
+// CHECK: cf.br ^bb14(%[[VAL_18]] : i8)
+// CHECK: ^bb14(%[[VAL_19:.*]]: i8):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK: ^bb15:
+// CHECK: return %[[VAL_19]] : i8
+// CHECK: ^bb16:
+// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i8
+// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i8, i8, i8)
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func.func @ipowi_vec(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xi64>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x3xi64>) {
+func.func @ipowi_vec(%arg0: vector<2x3xi64>, %arg1: vector<2x3xi64>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x3xi64>
+// CHECK: %[[B00:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xi64>
+// CHECK: %[[E00:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64>
+// CHECK: %[[R00:.*]] = call @__mlir_math_ipowi_i64(%[[B00]], %[[E00]]) : (i64, i64) -> i64
+// CHECK: %[[TMP00:.*]] = vector.insert %[[R00]], %[[CST]] [0, 0] : i64 into vector<2x3xi64>
+// CHECK: %[[B01:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xi64>
+// CHECK: %[[E01:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64>
+// CHECK: %[[R01:.*]] = call @__mlir_math_ipowi_i64(%[[B01]], %[[E01]]) : (i64, i64) -> i64
+// CHECK: %[[TMP01:.*]] = vector.insert %[[R01]], %[[TMP00]] [0, 1] : i64 into vector<2x3xi64>
+// CHECK: %[[B02:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xi64>
+// CHECK: %[[E02:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64>
+// CHECK: %[[R02:.*]] = call @__mlir_math_ipowi_i64(%[[B02]], %[[E02]]) : (i64, i64) -> i64
+// CHECK: %[[TMP02:.*]] = vector.insert %[[R02]], %[[TMP01]] [0, 2] : i64 into vector<2x3xi64>
+// CHECK: %[[B10:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xi64>
+// CHECK: %[[E10:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64>
+// CHECK: %[[R10:.*]] = call @__mlir_math_ipowi_i64(%[[B10]], %[[E10]]) : (i64, i64) -> i64
+// CHECK: %[[TMP10:.*]] = vector.insert %[[R10]], %[[TMP02]] [1, 0] : i64 into vector<2x3xi64>
+// CHECK: %[[B11:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xi64>
+// CHECK: %[[E11:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64>
+// CHECK: %[[R11:.*]] = call @__mlir_math_ipowi_i64(%[[B11]], %[[E11]]) : (i64, i64) -> i64
+// CHECK: %[[TMP11:.*]] = vector.insert %[[R11]], %[[TMP10]] [1, 1] : i64 into vector<2x3xi64>
+// CHECK: %[[B12:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xi64>
+// CHECK: %[[E12:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64>
+// CHECK: %[[R12:.*]] = call @__mlir_math_ipowi_i64(%[[B12]], %[[E12]]) : (i64, i64) -> i64
+// CHECK: %[[TMP12:.*]] = vector.insert %[[R12]], %[[TMP11]] [1, 2] : i64 into vector<2x3xi64>
+// CHECK: return
+// CHECK: }
+  %0 = math.ipowi %arg0, %arg1 : vector<2x3xi64>
+  func.return
+}