Changeset View
Standalone View
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
- This file was added.
//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// | |||||||||||||||||||||||||||
// | |||||||||||||||||||||||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||||||||||||||||||||||||
// See https://llvm.org/LICENSE.txt for license information. | |||||||||||||||||||||||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||||||||||||||||||||||||
// | |||||||||||||||||||||||||||
//===----------------------------------------------------------------------===// | |||||||||||||||||||||||||||
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h" | |||||||||||||||||||||||||||
#include "../PassDetail.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Func/IR/FuncOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Math/IR/Math.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Utils/IndexingUtils.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | |||||||||||||||||||||||||||
#include "mlir/IR/BuiltinDialect.h" | |||||||||||||||||||||||||||
#include "mlir/Transforms/DialectConversion.h" | |||||||||||||||||||||||||||
using namespace mlir; | |||||||||||||||||||||||||||
namespace { | |||||||||||||||||||||||||||
struct ConvertMathToFuncsPass | |||||||||||||||||||||||||||
: public ConvertMathToFuncsBase<ConvertMathToFuncsPass> { | |||||||||||||||||||||||||||
Mogball: Can you trim these includes? I don't think all of these are needed | |||||||||||||||||||||||||||
ConvertMathToFuncsPass() = default; | |||||||||||||||||||||||||||
void runOnOperation() override; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
// Pattern to convert vector operations to scalar operations. | |||||||||||||||||||||||||||
template <typename Op> | |||||||||||||||||||||||||||
struct VecOpToScalarOp : public OpRewritePattern<Op> { | |||||||||||||||||||||||||||
public: | |||||||||||||||||||||||||||
using OpRewritePattern<Op>::OpRewritePattern; | |||||||||||||||||||||||||||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; | |||||||||||||||||||||||||||
This isn't necessary. FunctionType can be hashed directly in a DenseMap. Mogball: This isn't necessary. `FunctionType` can be hashed directly in a `DenseMap`. | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
We generally don't use std::function Mogball: We generally don't use `std::function` | |||||||||||||||||||||||||||
struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> { | |||||||||||||||||||||||||||
private: | |||||||||||||||||||||||||||
func::FuncOp getElementFunc(math::IPowIOp op, IntegerType elementType, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const; | |||||||||||||||||||||||||||
public: | |||||||||||||||||||||||||||
IPowIOpLowering(MLIRContext *context, PatternBenefit benefit) | |||||||||||||||||||||||||||
: OpRewritePattern<math::IPowIOp>(context, benefit) {} | |||||||||||||||||||||||||||
/// Convert IPowI into a call to a local function implementing | |||||||||||||||||||||||||||
/// the power operation. The local function computes a scalar result, | |||||||||||||||||||||||||||
/// so vector forms of IPowI are linearized. | |||||||||||||||||||||||||||
LogicalResult matchAndRewrite(math::IPowIOp op, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const final; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
} // namespace | |||||||||||||||||||||||||||
template <typename Op> | |||||||||||||||||||||||||||
LogicalResult | |||||||||||||||||||||||||||
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { | |||||||||||||||||||||||||||
auto opType = op.getType(); | |||||||||||||||||||||||||||
Not Done ReplyInline Actions
Mogball: | |||||||||||||||||||||||||||
auto loc = op.getLoc(); | |||||||||||||||||||||||||||
MogballUnsubmitted
The type should only be omitted if it's explicit on the right hand side on an assignment due to a cast, or if it's "too long" (e.g. vector<int32_t>::iterator) Mogball: The type should only be omitted if it's explicit on the right hand side on an assignment due to… | |||||||||||||||||||||||||||
vzakhariAuthorUnsubmitted I will fix auto declarations. I will keep auto for results of rewriter.create<T> - it looks like it is a common practice to write it with auto. vzakhari: I will fix `auto` declarations. I will keep `auto` for results of `rewriter.create<T>` - it… | |||||||||||||||||||||||||||
auto vecType = opType.template dyn_cast<VectorType>(); | |||||||||||||||||||||||||||
if (!vecType) | |||||||||||||||||||||||||||
I think the "function type stringify" is too generic. For FPowI, just the floating point and integer type will suffice. This doesn't need to be done generally. I think you can just remove this function. Mogball: I think the "function type stringify" is too generic. For `FPowI`, just the floating point and… | |||||||||||||||||||||||||||
Okay, I will remove it. vzakhari: Okay, I will remove it. | |||||||||||||||||||||||||||
return failure(); | |||||||||||||||||||||||||||
if (!vecType.hasRank()) | |||||||||||||||||||||||||||
return failure(); | |||||||||||||||||||||||||||
MogballUnsubmitted Can you add a failure reason with rewriter.notifyMatchFailure? It can notoriously difficult to debug pattern matching without those messages. Mogball: Can you add a failure reason with `rewriter.notifyMatchFailure`? It can notoriously difficult… | |||||||||||||||||||||||||||
vzakhariAuthorUnsubmitted Sure! Thanks for the suggestion! vzakhari: Sure! Thanks for the suggestion! | |||||||||||||||||||||||||||
auto shape = vecType.getShape(); | |||||||||||||||||||||||||||
MogballUnsubmitted Spell out the auto. Mogball: Spell out the `auto`. | |||||||||||||||||||||||||||
int64_t numElements = vecType.getNumElements(); | |||||||||||||||||||||||||||
MogballUnsubmitted Should this be unsigned? Mogball: Should this be `unsigned`? | |||||||||||||||||||||||||||
vzakhariAuthorUnsubmitted int64_t should be correct, since ShapedType::getNumElements is defined as returning int64_t in IR/BuiltinAttributes.h. vzakhari: `int64_t` should be correct, since `ShapedType::getNumElements` is defined as returning… | |||||||||||||||||||||||||||
Value result = rewriter.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
loc, DenseElementsAttr::get( | |||||||||||||||||||||||||||
vecType, IntegerAttr::get(vecType.getElementType(), 0))); | |||||||||||||||||||||||||||
SmallVector<int64_t> ones(shape.size(), 1); | |||||||||||||||||||||||||||
SmallVector<int64_t> strides = computeStrides(shape, ones); | |||||||||||||||||||||||||||
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { | |||||||||||||||||||||||||||
MogballUnsubmitted Please don't use auto for a loop index. Mogball: Please don't use `auto` for a loop index. | |||||||||||||||||||||||||||
SmallVector<int64_t> positions = delinearize(strides, linearIndex); | |||||||||||||||||||||||||||
SmallVector<Value> operands; | |||||||||||||||||||||||||||
for (auto input : op->getOperands()) | |||||||||||||||||||||||||||
operands.push_back( | |||||||||||||||||||||||||||
rewriter.create<vector::ExtractOp>(loc, input, positions)); | |||||||||||||||||||||||||||
Value scalarOp = | |||||||||||||||||||||||||||
rewriter.create<Op>(loc, vecType.getElementType(), operands); | |||||||||||||||||||||||||||
result = | |||||||||||||||||||||||||||
rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
rewriter.replaceOp(op, {result}); | |||||||||||||||||||||||||||
MogballUnsubmitted
Mogball: | |||||||||||||||||||||||||||
return success(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
/// Create linkonce_odr function to implement the power function with | |||||||||||||||||||||||||||
/// the given \p elementType scalar type. | |||||||||||||||||||||||||||
/// | |||||||||||||||||||||||||||
/// template <typename T> | |||||||||||||||||||||||||||
/// T __mlir_math_ipowi_*(T b, T p) { | |||||||||||||||||||||||||||
/// if (p == T(0)) | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// if (p < T(0)) { | |||||||||||||||||||||||||||
/// if (b == T(0)) | |||||||||||||||||||||||||||
/// return T(1) / T(0); // trigger div-by-zero | |||||||||||||||||||||||||||
/// if (b == T(1)) | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// if (b == T(-1)) { | |||||||||||||||||||||||||||
/// if (p & T(1)) | |||||||||||||||||||||||||||
/// return T(-1); | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// return T(0); | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// T result = T(1); | |||||||||||||||||||||||||||
/// while (true) { | |||||||||||||||||||||||||||
/// if (p & T(1)) | |||||||||||||||||||||||||||
/// result *= b; | |||||||||||||||||||||||||||
/// p >>= T(1); | |||||||||||||||||||||||||||
/// if (p == T(0)) | |||||||||||||||||||||||||||
/// return result; | |||||||||||||||||||||||||||
/// b *= b; | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
func::FuncOp IPowIOpLowering::getElementFunc(math::IPowIOp op, | |||||||||||||||||||||||||||
IntegerType elementType, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const { | |||||||||||||||||||||||||||
std::string funcName("__mlir_math_ipowi_"); | |||||||||||||||||||||||||||
It's generally not valid for patterns to access and modify parent operations, especially since these "patterns" are exposed publicly. The generation of these software implementations should not be done inside patterns. Patterns should be local rewrites of operations. Mogball: It's generally not valid for patterns to access and modify parent operations, especially since… | |||||||||||||||||||||||||||
I understand the concern. I followed the same approach that MathToLibm and ComplexToLibm use. I guess I can add a comment in MathToFuncs.h to warn about valid usage of populateMathToFuncsConversionPatterns - does it make sense? vzakhari: I understand the concern. I followed the same approach that `MathToLibm` and `ComplexToLibm`… | |||||||||||||||||||||||||||
No. Those other patterns are also invalid. Patterns should not be modifying parent operations. It's something that "just happens to work" because no one has needed to compose these patterns sets in a parallelized pass (which would result in race conditions). This is especially true in conversion patterns, because there is extra bookkeeping in the dialect conversion pass that is violated if this happens. Please don't propagate what has been done in the older passes and separate these. Mogball: No. Those other patterns are also invalid. Patterns should not be modifying parent operations. | |||||||||||||||||||||||||||
Thank you for the explanation! Then it looks like exposing the patterns is unsafe, in general, and I should rewrite the pass as you suggested without externalizing the match patterns (populateMathToFuncsConversionPatterns). The module pass will first visit all IPowI operations for collecting information about the kinds of implementation functions that need to be created in the module, then it will create the new functions, and then it will apply populateMathToFuncsConversionPatterns patterns to rewrite the operations into the calls. vzakhari: Thank you for the explanation!
Then it looks like exposing the patterns is unsafe, in general… | |||||||||||||||||||||||||||
llvm::raw_string_ostream nameOS(funcName); | |||||||||||||||||||||||||||
elementType.print(nameOS); | |||||||||||||||||||||||||||
auto module = SymbolTable::getNearestSymbolTable(op); | |||||||||||||||||||||||||||
MogballUnsubmitted Please spell out all the autos as appropriate. Mogball: Please spell out all the `auto`s as appropriate. | |||||||||||||||||||||||||||
auto funcType = FunctionType::get(rewriter.getContext(), | |||||||||||||||||||||||||||
{elementType, elementType}, {elementType}); | |||||||||||||||||||||||||||
if (auto funcOp = dyn_cast_or_null<func::FuncOp>( | |||||||||||||||||||||||||||
SymbolTable::lookupSymbolIn(module, funcName))) { | |||||||||||||||||||||||||||
assert(funcOp.getFunctionTypeAttr().getValue() == funcType && | |||||||||||||||||||||||||||
"ipowi function type mismatch"); | |||||||||||||||||||||||||||
return funcOp; | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
OpBuilder::InsertionGuard guard(rewriter); | |||||||||||||||||||||||||||
MogballUnsubmitted Would generating this function be made any easier by using scf instead of cf? Mogball: Would generating this function be made any easier by using `scf` instead of `cf`? | |||||||||||||||||||||||||||
vzakhariAuthorUnsubmitted It seems to me unstructured control flow fits better for the code with early returns from the function. vzakhari: It seems to me unstructured control flow fits better for the code with early returns from the… | |||||||||||||||||||||||||||
auto loc = rewriter.getUnknownLoc(); | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(&module->getRegion(0).front()); | |||||||||||||||||||||||||||
auto funcOp = rewriter.create<func::FuncOp>(loc, funcName, funcType); | |||||||||||||||||||||||||||
auto inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; | |||||||||||||||||||||||||||
auto linkage = LLVM::LinkageAttr::get(rewriter.getContext(), inlineLinkage); | |||||||||||||||||||||||||||
funcOp->setAttr("llvm.linkage", linkage); | |||||||||||||||||||||||||||
funcOp.setPrivate(); | |||||||||||||||||||||||||||
auto *entryBlock = funcOp.addEntryBlock(); | |||||||||||||||||||||||||||
auto *funcBody = entryBlock->getParent(); | |||||||||||||||||||||||||||
auto bArg = funcOp.getArgument(0); | |||||||||||||||||||||||||||
auto pArg = funcOp.getArgument(1); | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(entryBlock); | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
Will do. flush should not be necessary according to this: https://llvm.org/doxygen/classllvm_1_1raw__string__ostream.html#details vzakhari: Will do. `flush` should not be necessary according to this: https://llvm. | |||||||||||||||||||||||||||
Value zeroValue = rewriter.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
I would prefer this to be a standalone function, not a static function on the pattern. Mogball: I would prefer this to be a standalone function, not a static function on the pattern. | |||||||||||||||||||||||||||
loc, elementType, rewriter.getIntegerAttr(elementType, 0)); | |||||||||||||||||||||||||||
Value oneValue = rewriter.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
loc, elementType, rewriter.getIntegerAttr(elementType, 1)); | |||||||||||||||||||||||||||
Value minusOneValue = rewriter.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
loc, elementType, | |||||||||||||||||||||||||||
rewriter.getIntegerAttr(elementType, | |||||||||||||||||||||||||||
APInt(elementType.getIntOrFloatBitWidth(), -1ULL, | |||||||||||||||||||||||||||
/*isSigned=*/true))); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
auto pIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | |||||||||||||||||||||||||||
pArg, zeroValue); | |||||||||||||||||||||||||||
auto *thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, oneValue); | |||||||||||||||||||||||||||
auto *fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p == T(0)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(pIsZero->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, pIsZero, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (p < T(0)) { | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto pIsNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, | |||||||||||||||||||||||||||
pArg, zeroValue); | |||||||||||||||||||||||||||
// if (b == T(0)) | |||||||||||||||||||||||||||
rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
auto bIsZero = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | |||||||||||||||||||||||||||
bArg, zeroValue); | |||||||||||||||||||||||||||
// return T(1) / T(0); | |||||||||||||||||||||||||||
thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>( | |||||||||||||||||||||||||||
loc, | |||||||||||||||||||||||||||
rewriter.create<arith::DivSIOp>(loc, oneValue, zeroValue).getResult()); | |||||||||||||||||||||||||||
fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(0)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(bIsZero->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, bIsZero, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (b == T(1)) | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto bIsOne = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | |||||||||||||||||||||||||||
bArg, oneValue); | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, oneValue); | |||||||||||||||||||||||||||
fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(1)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(bIsOne->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, bIsOne, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (b == T(-1)) { | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto bIsMinusOne = rewriter.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
loc, arith::CmpIPredicate::eq, bArg, minusOneValue); | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
auto pIsOdd = rewriter.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
loc, arith::CmpIPredicate::ne, | |||||||||||||||||||||||||||
rewriter.create<arith::AndIOp>(loc, pArg, oneValue), zeroValue); | |||||||||||||||||||||||||||
// return T(-1); | |||||||||||||||||||||||||||
thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, minusOneValue); | |||||||||||||||||||||||||||
fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p & T(1)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(pIsOdd->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, pIsOdd, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
// } // b == T(-1) | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, oneValue); | |||||||||||||||||||||||||||
fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(-1)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(bIsMinusOne->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, bIsMinusOne, pIsOdd->getBlock(), | |||||||||||||||||||||||||||
fallthroughBlock); | |||||||||||||||||||||||||||
// return T(0); | |||||||||||||||||||||||||||
// } // (p < T(0)) | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, zeroValue); | |||||||||||||||||||||||||||
auto *loopHeader = rewriter.createBlock( | |||||||||||||||||||||||||||
funcBody, funcBody->end(), {elementType, elementType, elementType}, | |||||||||||||||||||||||||||
{loc, loc, loc}); | |||||||||||||||||||||||||||
// Set up conditional branch for (p < T(0)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(pIsNeg->getBlock()); | |||||||||||||||||||||||||||
// Set initial values of 'result', 'b' and 'p' for the loop. | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, pIsNeg, bIsZero->getBlock(), | |||||||||||||||||||||||||||
loopHeader, | |||||||||||||||||||||||||||
ValueRange{oneValue, bArg, pArg}); | |||||||||||||||||||||||||||
// T result = T(1); | |||||||||||||||||||||||||||
// while (true) { | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
// result *= b; | |||||||||||||||||||||||||||
// p >>= T(1); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
// return result; | |||||||||||||||||||||||||||
// b *= b; | |||||||||||||||||||||||||||
// } | |||||||||||||||||||||||||||
Value resultTmp = loopHeader->getArgument(0); | |||||||||||||||||||||||||||
Value baseTmp = loopHeader->getArgument(1); | |||||||||||||||||||||||||||
Value powerTmp = loopHeader->getArgument(2); | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(loopHeader); | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
auto powerTmpIsOdd = rewriter.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
loc, arith::CmpIPredicate::ne, | |||||||||||||||||||||||||||
rewriter.create<arith::AndIOp>(loc, powerTmp, oneValue), zeroValue); | |||||||||||||||||||||||||||
thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// result *= b; | |||||||||||||||||||||||||||
Value newResultTmp = rewriter.create<arith::MulIOp>(loc, resultTmp, baseTmp); | |||||||||||||||||||||||||||
fallthroughBlock = | |||||||||||||||||||||||||||
rewriter.createBlock(funcBody, funcBody->end(), elementType, loc); | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(thenBlock); | |||||||||||||||||||||||||||
rewriter.create<cf::BranchOp>(loc, newResultTmp, fallthroughBlock); | |||||||||||||||||||||||||||
// Set up conditional branch for (p & T(1)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, powerTmpIsOdd, thenBlock, | |||||||||||||||||||||||||||
fallthroughBlock, resultTmp); | |||||||||||||||||||||||||||
// Merged 'result'. | |||||||||||||||||||||||||||
newResultTmp = fallthroughBlock->getArgument(0); | |||||||||||||||||||||||||||
// p >>= T(1); | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
Value newPowerTmp = rewriter.create<arith::ShRUIOp>(loc, powerTmp, oneValue); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
auto newPowerIsZero = rewriter.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
loc, arith::CmpIPredicate::eq, newPowerTmp, zeroValue); | |||||||||||||||||||||||||||
// return result; | |||||||||||||||||||||||||||
thenBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
rewriter.create<func::ReturnOp>(loc, newResultTmp); | |||||||||||||||||||||||||||
fallthroughBlock = rewriter.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p == T(0)). | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(newPowerIsZero->getBlock()); | |||||||||||||||||||||||||||
rewriter.create<cf::CondBranchOp>(loc, newPowerIsZero, thenBlock, | |||||||||||||||||||||||||||
fallthroughBlock); | |||||||||||||||||||||||||||
// b *= b; | |||||||||||||||||||||||||||
// } | |||||||||||||||||||||||||||
rewriter.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
Value newBaseTmp = rewriter.create<arith::MulIOp>(loc, baseTmp, baseTmp); | |||||||||||||||||||||||||||
// Pass new values for 'result', 'b' and 'p' to the loop header. | |||||||||||||||||||||||||||
rewriter.create<cf::BranchOp>( | |||||||||||||||||||||||||||
loc, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); | |||||||||||||||||||||||||||
return funcOp; | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
/// Convert IPowI into a call to a local function implementing | |||||||||||||||||||||||||||
/// the power operation. The local function computes a scalar result, | |||||||||||||||||||||||||||
/// so vector forms of IPowI are linearized. | |||||||||||||||||||||||||||
LogicalResult | |||||||||||||||||||||||||||
IPowIOpLowering::matchAndRewrite(math::IPowIOp op, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const { | |||||||||||||||||||||||||||
auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>(); | |||||||||||||||||||||||||||
auto exponentType = op.getOperands()[1].getType().dyn_cast<IntegerType>(); | |||||||||||||||||||||||||||
if (!baseType || baseType != exponentType) | |||||||||||||||||||||||||||
return failure(); | |||||||||||||||||||||||||||
auto elementFunc = getElementFunc(op, baseType, rewriter); | |||||||||||||||||||||||||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands()); | |||||||||||||||||||||||||||
return success(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
void ConvertMathToFuncsPass::runOnOperation() { | |||||||||||||||||||||||||||
auto module = getOperation(); | |||||||||||||||||||||||||||
RewritePatternSet patterns(&getContext()); | |||||||||||||||||||||||||||
populateMathToFuncsConversionPatterns(patterns, /*benefit=*/1); | |||||||||||||||||||||||||||
ConversionTarget target(getContext()); | |||||||||||||||||||||||||||
target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, | |||||||||||||||||||||||||||
cf::ControlFlowDialect, func::FuncDialect, | |||||||||||||||||||||||||||
vector::VectorDialect>(); | |||||||||||||||||||||||||||
MogballUnsubmitted Since this is a partial dialect conversion, can you not just add math::IPowI as an illegal op? Mogball: Since this is a partial dialect conversion, can you not just add `math::IPowI` as an illegal op? | |||||||||||||||||||||||||||
vzakhariAuthorUnsubmitted Will do. vzakhari: Will do. | |||||||||||||||||||||||||||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) | |||||||||||||||||||||||||||
signalPassFailure(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
void mlir::populateMathToFuncsConversionPatterns(RewritePatternSet &patterns, | |||||||||||||||||||||||||||
PatternBenefit benefit) { | |||||||||||||||||||||||||||
patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext(), benefit); | |||||||||||||||||||||||||||
patterns.add<IPowIOpLowering>(patterns.getContext(), benefit); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
MogballUnsubmitted Nit: typically this comes before the pass declaration/definition Mogball: Nit: typically this comes before the pass declaration/definition | |||||||||||||||||||||||||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToFuncsPass() { | |||||||||||||||||||||||||||
return std::make_unique<ConvertMathToFuncsPass>(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
Since I am going to add FPowI support later, does it make sense to keep the switch now? vzakhari: Since I am going to add `FPowI` support later, does it make sense to keep the switch now? | |||||||||||||||||||||||||||
Yes it does. Thanks. Mogball: Yes it does. Thanks. | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
If all the types are the same, then you only need to hash based on the element type. Mogball: If all the types are the same, then you only need to hash based on the element type. | |||||||||||||||||||||||||||
This is true for IPowI but not for FPowI. I prefer to keep it as a function type so that the keys are consistent in the map, otherwise, for IPowI we will have IntegerType keys and for FPowI we will have FunctionType keys. vzakhari: This is true for `IPowI` but not for `FPowI`. I prefer to keep it as a function type so that… | |||||||||||||||||||||||||||
Then you can just use Type as the key. Mogball: Then you can just use `Type` as the key. | |||||||||||||||||||||||||||
In any case, you don't need std::map and a custom comparator. DenseMap<Type, T> works just fine. Mogball: In any case, you don't need `std::map` and a custom comparator. `DenseMap<Type, T>` works just… | |||||||||||||||||||||||||||
Should be done now. Thanks! vzakhari: Should be done now. Thanks! | |||||||||||||||||||||||||||
Not Done ReplyInline Actions
Can you remove the extra logic for now? I know you want to add support for other ops, but I think it'd be less automagical to just add another Case for FPowI Mogball: Can you remove the extra logic for now? I know you want to add support for other ops, but I… |
Can you trim these includes? I don't think all of these are needed