Index: mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h =================================================================== --- mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h +++ mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -14,13 +14,16 @@ namespace mlir { class LLVMTypeConverter; -class RewritePatternSet; +class ModuleOp; +template +class OperationPass; class Pass; +class RewritePatternSet; void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); -std::unique_ptr createConvertMathToLLVMPass(); +std::unique_ptr> createConvertMathToLLVMPass(); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H Index: mlir/include/mlir/Conversion/Passes.td =================================================================== --- mlir/include/mlir/Conversion/Passes.td +++ mlir/include/mlir/Conversion/Passes.td @@ -488,7 +488,7 @@ // MathToLLVM //===----------------------------------------------------------------------===// -def ConvertMathToLLVM : Pass<"convert-math-to-llvm"> { +def ConvertMathToLLVM : Pass<"convert-math-to-llvm", "ModuleOp"> { let summary = "Convert Math dialect to LLVM dialect"; let description = [{ This pass converts supported Math ops to LLVM dialect intrinsics. Index: mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp =================================================================== --- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -248,6 +248,275 @@ } }; +struct IPowIOpLowering : public ConvertOpToLLVMPattern { +private: + /// Create linkonce_odr function to implement the power function with + /// the given \p elementType scalar type. + /// + /// template + /// T __mlir_math_ipowi_*(T b, T p) { + /// if (p == T(0)) + /// return T(1); + /// if (p < T(0)) { + /// if (b == T(0)) + /// return T(1) / T(0); // trigger div-by-zero + /// if (b == T(1)) + /// return T(1); + /// if (b == T(-1)) { + /// if (p & T(1)) + /// return T(-1); + /// return T(1); + /// } + /// return T(0); + /// } + /// T result = T(1); + /// while (true) { + /// if (p & T(1)) + /// result *= b; + /// p >>= T(1); + /// if (p == T(0)) + /// return result; + /// b *= b; + /// } + /// } + LLVM::LLVMFuncOp getElementFunc(math::IPowIOp op, IntegerType elementType, + PatternRewriter &rewriter) const { + std::string funcName("__mlir_math_ipowi_"); + llvm::raw_string_ostream nameOS(funcName); + elementType.print(nameOS); + + auto module = SymbolTable::getNearestSymbolTable(op); + auto funcType = + LLVM::LLVMFunctionType::get(elementType, {elementType, elementType}); + if (auto funcOp = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, funcName))) { + assert(funcOp.getFunctionTypeAttr().getValue() == funcType && + "ipowi function type mismatch"); + return funcOp; + } + + OpBuilder::InsertionGuard guard(rewriter); + auto loc = rewriter.getUnknownLoc(); + rewriter.setInsertionPointToEnd(&module->getRegion(0).front()); + auto funcOp = rewriter.create(loc, funcName, funcType, + LLVM::Linkage::LinkonceODR); + funcOp.setPrivate(); + + auto *entryBlock = funcOp.addEntryBlock(); + auto *funcBody = entryBlock->getParent(); + + auto bArg = funcOp.getArgument(0); + auto pArg = funcOp.getArgument(1); + rewriter.setInsertionPointToEnd(entryBlock); + Value zeroValue = rewriter.create( + loc, elementType, rewriter.getIntegerAttr(elementType, 0)); + Value oneValue = rewriter.create( + loc, elementType, rewriter.getIntegerAttr(elementType, 1)); + Value minusOneValue = rewriter.create( + loc, elementType, + rewriter.getIntegerAttr(elementType, + APInt(elementType.getIntOrFloatBitWidth(), + -1ULL, /*isSigned=*/true))); + + // if (p == T(0)) + // return T(1); + auto pIsZero = rewriter.create(loc, LLVM::ICmpPredicate::eq, + pArg, zeroValue); + auto *thenBlock = rewriter.createBlock(funcBody); + rewriter.create(loc, oneValue); + auto *fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (p == T(0)). + rewriter.setInsertionPointToEnd(pIsZero->getBlock()); + rewriter.create(loc, pIsZero, thenBlock, fallthroughBlock); + + // if (p < T(0)) { + rewriter.setInsertionPointToEnd(fallthroughBlock); + auto pIsNeg = rewriter.create(loc, LLVM::ICmpPredicate::sle, + pArg, zeroValue); + // if (b == T(0)) + rewriter.createBlock(funcBody); + auto bIsZero = rewriter.create(loc, LLVM::ICmpPredicate::eq, + bArg, zeroValue); + // return T(1) / T(0); + thenBlock = rewriter.createBlock(funcBody); + rewriter.create( + loc, + rewriter.create(loc, oneValue, zeroValue).getResult()); + fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (b == T(0)). + rewriter.setInsertionPointToEnd(bIsZero->getBlock()); + rewriter.create(loc, bIsZero, thenBlock, fallthroughBlock); + + // if (b == T(1)) + rewriter.setInsertionPointToEnd(fallthroughBlock); + auto bIsOne = rewriter.create(loc, LLVM::ICmpPredicate::eq, + bArg, oneValue); + // return T(1); + thenBlock = rewriter.createBlock(funcBody); + rewriter.create(loc, oneValue); + fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (b == T(1)). + rewriter.setInsertionPointToEnd(bIsOne->getBlock()); + rewriter.create(loc, bIsOne, thenBlock, fallthroughBlock); + + // if (b == T(-1)) { + rewriter.setInsertionPointToEnd(fallthroughBlock); + auto bIsMinusOne = rewriter.create( + loc, LLVM::ICmpPredicate::eq, bArg, minusOneValue); + // if (p & T(1)) + rewriter.createBlock(funcBody); + auto pIsOdd = rewriter.create( + loc, LLVM::ICmpPredicate::ne, + rewriter.create(loc, pArg, oneValue), zeroValue); + // return T(-1); + thenBlock = rewriter.createBlock(funcBody); + rewriter.create(loc, minusOneValue); + fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (p & T(1)). + rewriter.setInsertionPointToEnd(pIsOdd->getBlock()); + rewriter.create(loc, pIsOdd, thenBlock, fallthroughBlock); + + // return T(1); + // } // b == T(-1) + rewriter.setInsertionPointToEnd(fallthroughBlock); + rewriter.create(loc, oneValue); + fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (b == T(-1)). + rewriter.setInsertionPointToEnd(bIsMinusOne->getBlock()); + rewriter.create(loc, bIsMinusOne, pIsOdd->getBlock(), + fallthroughBlock); + + // return T(0); + // } // (p < T(0)) + rewriter.setInsertionPointToEnd(fallthroughBlock); + rewriter.create(loc, zeroValue); + auto *loopHeader = rewriter.createBlock( + funcBody, funcBody->end(), {elementType, elementType, elementType}, + {loc, loc, loc}); + // Set up conditional branch for (p < T(0)). + rewriter.setInsertionPointToEnd(pIsNeg->getBlock()); + // Set initial values of 'result', 'b' and 'p' for the loop. + rewriter.create(loc, pIsNeg, bIsZero->getBlock(), + loopHeader, + ValueRange{oneValue, bArg, pArg}); + + // T result = T(1); + // while (true) { + // if (p & T(1)) + // result *= b; + // p >>= T(1); + // if (p == T(0)) + // return result; + // b *= b; + // } + Value resultTmp = loopHeader->getArgument(0); + Value baseTmp = loopHeader->getArgument(1); + Value powerTmp = loopHeader->getArgument(2); + rewriter.setInsertionPointToEnd(loopHeader); + + // if (p & T(1)) + auto powerTmpIsOdd = rewriter.create( + loc, LLVM::ICmpPredicate::ne, + rewriter.create(loc, powerTmp, oneValue), zeroValue); + thenBlock = rewriter.createBlock(funcBody); + // result *= b; + Value newResultTmp = rewriter.create(loc, resultTmp, baseTmp); + fallthroughBlock = + rewriter.createBlock(funcBody, funcBody->end(), elementType, loc); + rewriter.setInsertionPointToEnd(thenBlock); + rewriter.create(loc, newResultTmp, fallthroughBlock); + // Set up conditional branch for (p & T(1)). + rewriter.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); + rewriter.create(loc, powerTmpIsOdd, thenBlock, + fallthroughBlock, resultTmp); + // Merged 'result'. + newResultTmp = fallthroughBlock->getArgument(0); + + // p >>= T(1); + rewriter.setInsertionPointToEnd(fallthroughBlock); + Value newPowerTmp = rewriter.create(loc, powerTmp, oneValue); + + // if (p == T(0)) + auto newPowerIsZero = rewriter.create( + loc, LLVM::ICmpPredicate::eq, newPowerTmp, zeroValue); + // return result; + thenBlock = rewriter.createBlock(funcBody); + rewriter.create(loc, newResultTmp); + fallthroughBlock = rewriter.createBlock(funcBody); + // Set up conditional branch for (p == T(0)). + rewriter.setInsertionPointToEnd(newPowerIsZero->getBlock()); + rewriter.create(loc, newPowerIsZero, thenBlock, + fallthroughBlock); + + // b *= b; + // } + rewriter.setInsertionPointToEnd(fallthroughBlock); + Value newBaseTmp = rewriter.create(loc, baseTmp, baseTmp); + // Pass new values for 'result', 'b' and 'p' to the loop header. + rewriter.create( + loc, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); + return funcOp; + } + +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + /// Convert IPowI into a call to a local function implementing + /// the power operation. The local function computes a scalar result, + /// so vector forms of IPowI are linearized. + LogicalResult + matchAndRewrite(math::IPowIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operandType = adaptor.getOperands()[0].getType(); + + if (!operandType || !LLVM::isCompatibleType(operandType)) + return failure(); + + auto loc = op.getLoc(); + if (auto elementType = operandType.dyn_cast()) { + auto elementFunc = getElementFunc(op, elementType, rewriter); + rewriter.replaceOpWithNewOp(op, elementFunc, + adaptor.getOperands()); + return success(); + } + + auto resultType = op.getResult().getType(); + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return failure(); + + auto elementType = vectorType.getElementType().dyn_cast(); + if (!elementType) + return failure(); + + auto elementFunc = getElementFunc(op, elementType, rewriter); + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + Value result = rewriter.create(loc, llvm1DVectorTy); + auto vecType = llvm1DVectorTy.cast(); + for (unsigned idx = 0, numElements = vecType.getNumElements(); + idx < numElements; ++idx) { + Value pos = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx)); + Value baseOp = + rewriter.create(loc, operands[0], pos); + Value powerOp = + rewriter.create(loc, operands[1], pos); + Value eltResult = + rewriter + .create(loc, elementFunc, + ValueRange{baseOp, powerOp}) + .getResult(0); + result = rewriter.create( + loc, llvm1DVectorTy, result, eltResult, pos); + } + return result; + }, + rewriter); + } +}; + struct ConvertMathToLLVMPass : public ConvertMathToLLVMBase { ConvertMathToLLVMPass() = default; @@ -257,6 +526,7 @@ LLVMTypeConverter converter(&getContext()); populateMathToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); + target.addIllegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -285,6 +555,7 @@ Log2OpLowering, LogOpLowering, PowFOpLowering, + IPowIOpLowering, RoundOpLowering, RsqrtOpLowering, SinOpLowering, @@ -293,6 +564,6 @@ // clang-format on } -std::unique_ptr mlir::createConvertMathToLLVMPass() { +std::unique_ptr> mlir::createConvertMathToLLVMPass() { return std::make_unique(); } Index: mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp =================================================================== --- mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -78,7 +78,7 @@ pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions())); pm.addPass(createMemRefToLLVMPass()); pm.addNestedPass(createConvertComplexToStandardPass()); - pm.addNestedPass(createConvertMathToLLVMPass()); + pm.addPass(createConvertMathToLLVMPass()); pm.addPass(createConvertMathToLibmPass()); pm.addPass(createConvertComplexToLibmPass()); pm.addPass(createConvertComplexToLLVMPass()); Index: mlir/test/Conversion/ComplexToStandard/full-conversion.mlir =================================================================== --- mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard),convert-complex-to-llvm,func.func(convert-math-to-llvm,convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard),convert-complex-to-llvm,convert-math-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" | FileCheck %s // CHECK-LABEL: llvm.func @complex_abs // CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) Index: mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir =================================================================== --- mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir +++ mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -pass-pipeline="func.func(convert-math-to-llvm,convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" %s -split-input-file | FileCheck %s -// RUN: mlir-opt -pass-pipeline="func.func(convert-math-to-llvm,convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32},reconcile-unrealized-casts" %s -split-input-file | FileCheck --check-prefix=CHECK32 %s +// RUN: mlir-opt -pass-pipeline="convert-math-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" %s -split-input-file | FileCheck %s +// RUN: mlir-opt -pass-pipeline="convert-math-to-llvm,func.func(convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32},reconcile-unrealized-casts" %s -split-input-file | FileCheck --check-prefix=CHECK32 %s // CHECK-LABEL: func @empty() { // CHECK-NEXT: llvm.return Index: mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir =================================================================== --- mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-math-to-llvm)" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -pass-pipeline="convert-math-to-llvm" | FileCheck %s // CHECK-LABEL: @ops func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) { @@ -181,3 +181,169 @@ %0 = math.round %arg0 : f32 func.return } + +// ----- + +// CHECK-LABEL: func @ipowi( +// CHECK-SAME: %[[ARG0:.+]]: i64, +// CHECK-SAME: %[[ARG1:.+]]: i64) +func.func @ipowi(%arg0: i64, %arg1: i64) { + // CHECK: llvm.call @__mlir_math_ipowi_i64(%[[ARG0]], %[[ARG1]]) : (i64, i64) -> i64 + %0 = math.ipowi %arg0, %arg1 : i64 + func.return +} + +// CHECK-LABEL: llvm.func linkonce_odr @__mlir_math_ipowi_i64( +// CHECK-SAME: %[[VAL_0:.*]]: i64, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64 attributes {sym_visibility = "private"} { +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(-1 : i64) : i64 +// CHECK: %[[VAL_5:.*]] = llvm.icmp "eq" %[[VAL_1]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_5]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: llvm.return %[[VAL_3]] : i64 +// CHECK: ^bb2: +// CHECK: %[[VAL_6:.*]] = llvm.icmp "sle" %[[VAL_1]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i64, i64, i64) +// CHECK: ^bb3: +// CHECK: %[[VAL_7:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_7]], ^bb4, ^bb5 +// CHECK: ^bb4: +// CHECK: %[[VAL_8:.*]] = llvm.sdiv %[[VAL_3]], %[[VAL_2]] : i64 +// CHECK: llvm.return %[[VAL_8]] : i64 +// CHECK: ^bb5: +// CHECK: %[[VAL_9:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_3]] : i64 +// CHECK: llvm.cond_br %[[VAL_9]], ^bb6, ^bb7 +// CHECK: ^bb6: +// CHECK: llvm.return %[[VAL_3]] : i64 +// CHECK: ^bb7: +// CHECK: %[[VAL_10:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_4]] : i64 +// CHECK: llvm.cond_br %[[VAL_10]], ^bb8, ^bb11 +// CHECK: ^bb8: +// CHECK: %[[VAL_11:.*]] = llvm.and %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_12:.*]] = llvm.icmp "ne" %[[VAL_11]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_12]], ^bb9, ^bb10 +// CHECK: ^bb9: +// CHECK: llvm.return %[[VAL_4]] : i64 +// CHECK: ^bb10: +// CHECK: llvm.return %[[VAL_3]] : i64 +// CHECK: ^bb11: +// CHECK: llvm.return %[[VAL_2]] : i64 +// CHECK: ^bb12(%[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i64, %[[VAL_15:.*]]: i64): +// CHECK: %[[VAL_16:.*]] = llvm.and %[[VAL_15]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_17:.*]] = llvm.icmp "ne" %[[VAL_16]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i64) +// CHECK: ^bb13: +// CHECK: %[[VAL_18:.*]] = llvm.mul %[[VAL_13]], %[[VAL_14]] : i64 +// CHECK: llvm.br ^bb14(%[[VAL_18]] : i64) +// CHECK: ^bb14(%[[VAL_19:.*]]: i64): +// CHECK: %[[VAL_20:.*]] = llvm.ashr %[[VAL_15]], %[[VAL_3]] : i64 +// CHECK: %[[VAL_21:.*]] = llvm.icmp "eq" %[[VAL_20]], %[[VAL_2]] : i64 +// CHECK: llvm.cond_br %[[VAL_21]], ^bb15, ^bb16 +// CHECK: ^bb15: +// CHECK: llvm.return %[[VAL_19]] : i64 +// CHECK: ^bb16: +// CHECK: %[[VAL_22:.*]] = llvm.mul %[[VAL_14]], %[[VAL_14]] : i64 +// CHECK: llvm.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i64, i64, i64) +// CHECK: } + +// ----- + +// CHECK-LABEL: func @ipowi( +// CHECK-SAME: %[[ARG0:.+]]: i8, +// CHECK-SAME: %[[ARG1:.+]]: i8) + // CHECK: llvm.call @__mlir_math_ipowi_i8(%[[ARG0]], %[[ARG1]]) : (i8, i8) -> i8 +func.func @ipowi(%arg0: i8, %arg1: i8) { + %0 = math.ipowi %arg0, %arg1 : i8 + func.return +} + +// CHECK-LABEL: llvm.func linkonce_odr @__mlir_math_ipowi_i8( +// CHECK-SAME: %[[VAL_0:.*]]: i8, +// CHECK-SAME: %[[VAL_1:.*]]: i8) -> i8 attributes {sym_visibility = "private"} { +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i8) : i8 +// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(1 : i8) : i8 +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(-1 : i8) : i8 +// CHECK: %[[VAL_5:.*]] = llvm.icmp "eq" %[[VAL_1]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_5]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: llvm.return %[[VAL_3]] : i8 +// CHECK: ^bb2: +// CHECK: %[[VAL_6:.*]] = llvm.icmp "sle" %[[VAL_1]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i8, i8, i8) +// CHECK: ^bb3: +// CHECK: %[[VAL_7:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_7]], ^bb4, ^bb5 +// CHECK: ^bb4: +// CHECK: %[[VAL_8:.*]] = llvm.sdiv %[[VAL_3]], %[[VAL_2]] : i8 +// CHECK: llvm.return %[[VAL_8]] : i8 +// CHECK: ^bb5: +// CHECK: %[[VAL_9:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_3]] : i8 +// CHECK: llvm.cond_br %[[VAL_9]], ^bb6, ^bb7 +// CHECK: ^bb6: +// CHECK: llvm.return %[[VAL_3]] : i8 +// CHECK: ^bb7: +// CHECK: %[[VAL_10:.*]] = llvm.icmp "eq" %[[VAL_0]], %[[VAL_4]] : i8 +// CHECK: llvm.cond_br %[[VAL_10]], ^bb8, ^bb11 +// CHECK: ^bb8: +// CHECK: %[[VAL_11:.*]] = llvm.and %[[VAL_1]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_12:.*]] = llvm.icmp "ne" %[[VAL_11]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_12]], ^bb9, ^bb10 +// CHECK: ^bb9: +// CHECK: llvm.return %[[VAL_4]] : i8 +// CHECK: ^bb10: +// CHECK: llvm.return %[[VAL_3]] : i8 +// CHECK: ^bb11: +// CHECK: llvm.return %[[VAL_2]] : i8 +// CHECK: ^bb12(%[[VAL_13:.*]]: i8, %[[VAL_14:.*]]: i8, %[[VAL_15:.*]]: i8): +// CHECK: %[[VAL_16:.*]] = llvm.and %[[VAL_15]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_17:.*]] = llvm.icmp "ne" %[[VAL_16]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i8) +// CHECK: ^bb13: +// CHECK: %[[VAL_18:.*]] = llvm.mul %[[VAL_13]], %[[VAL_14]] : i8 +// CHECK: llvm.br ^bb14(%[[VAL_18]] : i8) +// CHECK: ^bb14(%[[VAL_19:.*]]: i8): +// CHECK: %[[VAL_20:.*]] = llvm.ashr %[[VAL_15]], %[[VAL_3]] : i8 +// CHECK: %[[VAL_21:.*]] = llvm.icmp "eq" %[[VAL_20]], %[[VAL_2]] : i8 +// CHECK: llvm.cond_br %[[VAL_21]], ^bb15, ^bb16 +// CHECK: ^bb15: +// CHECK: llvm.return %[[VAL_19]] : i8 +// CHECK: ^bb16: +// CHECK: %[[VAL_22:.*]] = llvm.mul %[[VAL_14]], %[[VAL_14]] : i8 +// CHECK: llvm.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i8, i8, i8) +// CHECK: } + +// ----- + +// CHECK-LABEL: func.func @ipowi_vec( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xi64>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2x3xi64>) { +func.func @ipowi_vec(%arg0: vector<2x3xi64>, %arg1: vector<2x3xi64>) { +// CHECK: %[[VAL_4:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<3xi64>> +// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_2:.*]][0] : !llvm.array<2 x vector<3xi64>> +// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_3:.*]][0] : !llvm.array<2 x vector<3xi64>> +// CHECK: %[[VAL_7:.*]] = llvm.mlir.undef : vector<3xi64> +// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_9:.*]] = llvm.extractelement %[[VAL_5]]{{\[}}%[[VAL_8]] : i64] : vector<3xi64> +// CHECK: %[[VAL_10:.*]] = llvm.extractelement %[[VAL_6]]{{\[}}%[[VAL_8]] : i64] : vector<3xi64> +// CHECK: %[[VAL_11:.*]] = llvm.call @__mlir_math_ipowi_i64(%[[VAL_9]], %[[VAL_10]]) : (i64, i64) -> i64 +// CHECK: %[[VAL_12:.*]] = llvm.insertelement %[[VAL_11]], %[[VAL_7]]{{\[}}%[[VAL_8]] : i64] : vector<3xi64> +// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VAL_14:.*]] = llvm.extractelement %[[VAL_5]]{{\[}}%[[VAL_13]] : i64] : vector<3xi64> +// CHECK: %[[VAL_15:.*]] = llvm.extractelement %[[VAL_6]]{{\[}}%[[VAL_13]] : i64] : vector<3xi64> +// CHECK: %[[VAL_16:.*]] = llvm.call @__mlir_math_ipowi_i64(%[[VAL_14]], %[[VAL_15]]) : (i64, i64) -> i64 +// CHECK: %[[VAL_17:.*]] = llvm.insertelement %[[VAL_16]], %[[VAL_12]]{{\[}}%[[VAL_13]] : i64] : vector<3xi64> +// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[VAL_19:.*]] = llvm.extractelement %[[VAL_5]]{{\[}}%[[VAL_18]] : i64] : vector<3xi64> +// CHECK: %[[VAL_20:.*]] = llvm.extractelement %[[VAL_6]]{{\[}}%[[VAL_18]] : i64] : vector<3xi64> +// CHECK: %[[VAL_21:.*]] = llvm.call @__mlir_math_ipowi_i64(%[[VAL_19]], %[[VAL_20]]) : (i64, i64) -> i64 +// CHECK: %[[VAL_22:.*]] = llvm.insertelement %[[VAL_21]], %[[VAL_17]]{{\[}}%[[VAL_18]] : i64] : vector<3xi64> +// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_4]][0] : !llvm.array<2 x vector<3xi64>> + +// CHECK: llvm.call @__mlir_math_ipowi_i64(%{{.*}}, %{{.*}}) : (i64, i64) -> i64 +// CHECK: llvm.call @__mlir_math_ipowi_i64(%{{.*}}, %{{.*}}) : (i64, i64) -> i64 +// CHECK: llvm.call @__mlir_math_ipowi_i64(%{{.*}}, %{{.*}}) : (i64, i64) -> i64 + %0 = math.ipowi %arg0, %arg1 : vector<2x3xi64> + func.return +} Index: mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir =================================================================== --- mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="func.func(test-math-polynomial-approximation,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" \ +// RUN: mlir-opt %s -pass-pipeline="func.func(test-math-polynomial-approximation,convert-arith-to-llvm),convert-vector-to-llvm,convert-math-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts" \ // RUN: | mlir-cpu-runner \ // RUN: -e main -entry-point-result=void -O0 \ // RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \