Changeset View
Standalone View
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
- This file was added.
//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// | |||||||||||||||||||||||||||
// | |||||||||||||||||||||||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||||||||||||||||||||||||
// See https://llvm.org/LICENSE.txt for license information. | |||||||||||||||||||||||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||||||||||||||||||||||||
// | |||||||||||||||||||||||||||
//===----------------------------------------------------------------------===// | |||||||||||||||||||||||||||
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h" | |||||||||||||||||||||||||||
#include "../PassDetail.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Func/IR/FuncOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Math/IR/Math.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Utils/IndexingUtils.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | |||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | |||||||||||||||||||||||||||
#include "mlir/IR/ImplicitLocOpBuilder.h" | |||||||||||||||||||||||||||
#include "mlir/IR/TypeUtilities.h" | |||||||||||||||||||||||||||
#include "mlir/Transforms/DialectConversion.h" | |||||||||||||||||||||||||||
#include "llvm/ADT/DenseMap.h" | |||||||||||||||||||||||||||
#include "llvm/ADT/TypeSwitch.h" | |||||||||||||||||||||||||||
using namespace mlir; | |||||||||||||||||||||||||||
Mogball: Can you trim these includes? I don't think all of these are needed | |||||||||||||||||||||||||||
namespace { | |||||||||||||||||||||||||||
// Pattern to convert vector operations to scalar operations. | |||||||||||||||||||||||||||
template <typename Op> | |||||||||||||||||||||||||||
struct VecOpToScalarOp : public OpRewritePattern<Op> { | |||||||||||||||||||||||||||
public: | |||||||||||||||||||||||||||
using OpRewritePattern<Op>::OpRewritePattern; | |||||||||||||||||||||||||||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
// Callback type for getting pre-generated FuncOp implementing | |||||||||||||||||||||||||||
// a power operation of the given type. | |||||||||||||||||||||||||||
This isn't necessary. FunctionType can be hashed directly in a DenseMap. Mogball: This isn't necessary. `FunctionType` can be hashed directly in a `DenseMap`. | |||||||||||||||||||||||||||
using GetPowerFuncCallbackTy = std::function<func::FuncOp(FunctionType)>; | |||||||||||||||||||||||||||
We generally don't use std::function Mogball: We generally don't use `std::function` | |||||||||||||||||||||||||||
// Pattern to convert scalar IPowIOp into a call of outlined | |||||||||||||||||||||||||||
// software implementation. | |||||||||||||||||||||||||||
struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> { | |||||||||||||||||||||||||||
private: | |||||||||||||||||||||||||||
GetPowerFuncCallbackTy getFuncOpCallback; | |||||||||||||||||||||||||||
public: | |||||||||||||||||||||||||||
IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb) | |||||||||||||||||||||||||||
: OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {} | |||||||||||||||||||||||||||
/// Convert IPowI into a call to a local function implementing | |||||||||||||||||||||||||||
/// the power operation. The local function computes a scalar result, | |||||||||||||||||||||||||||
/// so vector forms of IPowI are linearized. | |||||||||||||||||||||||||||
LogicalResult matchAndRewrite(math::IPowIOp op, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const final; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
} // namespace | |||||||||||||||||||||||||||
/// Represent given FunctionType \p type as a string. | |||||||||||||||||||||||||||
std::string stringizeType(const FunctionType &type) { | |||||||||||||||||||||||||||
std::string result; | |||||||||||||||||||||||||||
Not Done ReplyInline Actions
Mogball: | |||||||||||||||||||||||||||
llvm::raw_string_ostream typeOS(result); | |||||||||||||||||||||||||||
for (unsigned i = 0, e = type.getNumResults(); i != e; ++i) | |||||||||||||||||||||||||||
The type should only be omitted if it's explicit on the right hand side on an assignment due to a cast, or if it's "too long" (e.g. vector<int32_t>::iterator) Mogball: The type should only be omitted if it's explicit on the right hand side on an assignment due to… | |||||||||||||||||||||||||||
I will fix auto declarations. I will keep auto for results of rewriter.create<T> - it looks like it is a common practice to write it with auto. vzakhari: I will fix `auto` declarations. I will keep `auto` for results of `rewriter.create<T>` - it… | |||||||||||||||||||||||||||
typeOS << '_' << type.getResult(i); | |||||||||||||||||||||||||||
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) | |||||||||||||||||||||||||||
I think the "function type stringify" is too generic. For FPowI, just the floating point and integer type will suffice. This doesn't need to be done generally. I think you can just remove this function. Mogball: I think the "function type stringify" is too generic. For `FPowI`, just the floating point and… | |||||||||||||||||||||||||||
Okay, I will remove it. vzakhari: Okay, I will remove it. | |||||||||||||||||||||||||||
typeOS << '_' << type.getInput(i); | |||||||||||||||||||||||||||
assert(!result.empty() && "invalid FunctionType"); | |||||||||||||||||||||||||||
return result; | |||||||||||||||||||||||||||
Can you add a failure reason with rewriter.notifyMatchFailure? It can notoriously difficult to debug pattern matching without those messages. Mogball: Can you add a failure reason with `rewriter.notifyMatchFailure`? It can notoriously difficult… | |||||||||||||||||||||||||||
Sure! Thanks for the suggestion! vzakhari: Sure! Thanks for the suggestion! | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
Spell out the auto. Mogball: Spell out the `auto`. | |||||||||||||||||||||||||||
Should this be unsigned? Mogball: Should this be `unsigned`? | |||||||||||||||||||||||||||
int64_t should be correct, since ShapedType::getNumElements is defined as returning int64_t in IR/BuiltinAttributes.h. vzakhari: `int64_t` should be correct, since `ShapedType::getNumElements` is defined as returning… | |||||||||||||||||||||||||||
template <typename Op> | |||||||||||||||||||||||||||
LogicalResult | |||||||||||||||||||||||||||
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { | |||||||||||||||||||||||||||
Type opType = op.getType(); | |||||||||||||||||||||||||||
Location loc = op.getLoc(); | |||||||||||||||||||||||||||
auto vecType = opType.template dyn_cast<VectorType>(); | |||||||||||||||||||||||||||
Please don't use auto for a loop index. Mogball: Please don't use `auto` for a loop index. | |||||||||||||||||||||||||||
if (!vecType) | |||||||||||||||||||||||||||
return rewriter.notifyMatchFailure(op, "not a vector operation"); | |||||||||||||||||||||||||||
if (!vecType.hasRank()) | |||||||||||||||||||||||||||
return rewriter.notifyMatchFailure(op, "unknown vector rank"); | |||||||||||||||||||||||||||
ArrayRef<int64_t> shape = vecType.getShape(); | |||||||||||||||||||||||||||
int64_t numElements = vecType.getNumElements(); | |||||||||||||||||||||||||||
Value result = rewriter.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
loc, DenseElementsAttr::get( | |||||||||||||||||||||||||||
vecType, IntegerAttr::get(vecType.getElementType(), 0))); | |||||||||||||||||||||||||||
SmallVector<int64_t> ones(shape.size(), 1); | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
SmallVector<int64_t> strides = computeStrides(shape, ones); | |||||||||||||||||||||||||||
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { | |||||||||||||||||||||||||||
SmallVector<int64_t> positions = delinearize(strides, linearIndex); | |||||||||||||||||||||||||||
SmallVector<Value> operands; | |||||||||||||||||||||||||||
for (Value input : op->getOperands()) | |||||||||||||||||||||||||||
operands.push_back( | |||||||||||||||||||||||||||
rewriter.create<vector::ExtractOp>(loc, input, positions)); | |||||||||||||||||||||||||||
Value scalarOp = | |||||||||||||||||||||||||||
rewriter.create<Op>(loc, vecType.getElementType(), operands); | |||||||||||||||||||||||||||
result = | |||||||||||||||||||||||||||
rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
rewriter.replaceOp(op, result); | |||||||||||||||||||||||||||
return success(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
/// Create linkonce_odr function to implement the power function with | |||||||||||||||||||||||||||
/// the given \p funcType type inside \p module. \p funcType must be | |||||||||||||||||||||||||||
/// 'IntegerType (*)(IntegerType, IntegerType)' function type. | |||||||||||||||||||||||||||
/// | |||||||||||||||||||||||||||
/// template <typename T> | |||||||||||||||||||||||||||
/// T __mlir_math_ipowi_*(T b, T p) { | |||||||||||||||||||||||||||
/// if (p == T(0)) | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// if (p < T(0)) { | |||||||||||||||||||||||||||
/// if (b == T(0)) | |||||||||||||||||||||||||||
/// return T(1) / T(0); // trigger div-by-zero | |||||||||||||||||||||||||||
/// if (b == T(1)) | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// if (b == T(-1)) { | |||||||||||||||||||||||||||
/// if (p & T(1)) | |||||||||||||||||||||||||||
/// return T(-1); | |||||||||||||||||||||||||||
/// return T(1); | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// return T(0); | |||||||||||||||||||||||||||
It's generally not valid for patterns to access and modify parent operations, especially since these "patterns" are exposed publicly. The generation of these software implementations should not be done inside patterns. Patterns should be local rewrites of operations. Mogball: It's generally not valid for patterns to access and modify parent operations, especially since… | |||||||||||||||||||||||||||
I understand the concern. I followed the same approach that MathToLibm and ComplexToLibm use. I guess I can add a comment in MathToFuncs.h to warn about valid usage of populateMathToFuncsConversionPatterns - does it make sense? vzakhari: I understand the concern. I followed the same approach that `MathToLibm` and `ComplexToLibm`… | |||||||||||||||||||||||||||
No. Those other patterns are also invalid. Patterns should not be modifying parent operations. It's something that "just happens to work" because no one has needed to compose these patterns sets in a parallelized pass (which would result in race conditions). This is especially true in conversion patterns, because there is extra bookkeeping in the dialect conversion pass that is violated if this happens. Please don't propagate what has been done in the older passes and separate these. Mogball: No. Those other patterns are also invalid. Patterns should not be modifying parent operations. | |||||||||||||||||||||||||||
Thank you for the explanation! Then it looks like exposing the patterns is unsafe, in general, and I should rewrite the pass as you suggested without externalizing the match patterns (populateMathToFuncsConversionPatterns). The module pass will first visit all IPowI operations for collecting information about the kinds of implementation functions that need to be created in the module, then it will create the new functions, and then it will apply populateMathToFuncsConversionPatterns patterns to rewrite the operations into the calls. vzakhari: Thank you for the explanation!
Then it looks like exposing the patterns is unsafe, in general… | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// T result = T(1); | |||||||||||||||||||||||||||
/// while (true) { | |||||||||||||||||||||||||||
/// if (p & T(1)) | |||||||||||||||||||||||||||
/// result *= b; | |||||||||||||||||||||||||||
Please spell out all the autos as appropriate. Mogball: Please spell out all the `auto`s as appropriate. | |||||||||||||||||||||||||||
/// p >>= T(1); | |||||||||||||||||||||||||||
/// if (p == T(0)) | |||||||||||||||||||||||||||
/// return result; | |||||||||||||||||||||||||||
/// b *= b; | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
/// } | |||||||||||||||||||||||||||
static func::FuncOp createElementIPowIFunc(ModuleOp *module, | |||||||||||||||||||||||||||
FunctionType funcType) { | |||||||||||||||||||||||||||
assert(funcType.getNumResults() == 1 && funcType.getNumInputs() == 2 && | |||||||||||||||||||||||||||
funcType.getResult(0).isa<IntegerType>() && | |||||||||||||||||||||||||||
Would generating this function be made any easier by using scf instead of cf? Mogball: Would generating this function be made any easier by using `scf` instead of `cf`? | |||||||||||||||||||||||||||
It seems to me unstructured control flow fits better for the code with early returns from the function. vzakhari: It seems to me unstructured control flow fits better for the code with early returns from the… | |||||||||||||||||||||||||||
funcType.getInput(0).isa<IntegerType>() && | |||||||||||||||||||||||||||
funcType.getInput(1).isa<IntegerType>() && | |||||||||||||||||||||||||||
"invalid function type deduced from IPowIOp"); | |||||||||||||||||||||||||||
IntegerType elementType = funcType.getInput(0).cast<IntegerType>(); | |||||||||||||||||||||||||||
ImplicitLocOpBuilder builder = | |||||||||||||||||||||||||||
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); | |||||||||||||||||||||||||||
std::string funcName("__mlir_math_ipowi"); | |||||||||||||||||||||||||||
funcName += stringizeType(funcType); | |||||||||||||||||||||||||||
auto funcOp = builder.create<func::FuncOp>(funcName, funcType); | |||||||||||||||||||||||||||
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
Will do. flush should not be necessary according to this: https://llvm.org/doxygen/classllvm_1_1raw__string__ostream.html#details vzakhari: Will do. `flush` should not be necessary according to this: https://llvm. | |||||||||||||||||||||||||||
Attribute linkage = | |||||||||||||||||||||||||||
I would prefer this to be a standalone function, not a static function on the pattern. Mogball: I would prefer this to be a standalone function, not a static function on the pattern. | |||||||||||||||||||||||||||
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); | |||||||||||||||||||||||||||
funcOp->setAttr("llvm.linkage", linkage); | |||||||||||||||||||||||||||
funcOp.setPrivate(); | |||||||||||||||||||||||||||
Block *entryBlock = funcOp.addEntryBlock(); | |||||||||||||||||||||||||||
Region *funcBody = entryBlock->getParent(); | |||||||||||||||||||||||||||
Value bArg = funcOp.getArgument(0); | |||||||||||||||||||||||||||
Value pArg = funcOp.getArgument(1); | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(entryBlock); | |||||||||||||||||||||||||||
Value zeroValue = builder.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
elementType, builder.getIntegerAttr(elementType, 0)); | |||||||||||||||||||||||||||
Value oneValue = builder.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
elementType, builder.getIntegerAttr(elementType, 1)); | |||||||||||||||||||||||||||
Value minusOneValue = builder.create<arith::ConstantOp>( | |||||||||||||||||||||||||||
elementType, | |||||||||||||||||||||||||||
builder.getIntegerAttr(elementType, | |||||||||||||||||||||||||||
APInt(elementType.getIntOrFloatBitWidth(), -1ULL, | |||||||||||||||||||||||||||
/*isSigned=*/true))); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
auto pIsZero = | |||||||||||||||||||||||||||
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue); | |||||||||||||||||||||||||||
Block *thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(oneValue); | |||||||||||||||||||||||||||
Block *fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p == T(0)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(pIsZero->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (p < T(0)) { | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto pIsNeg = | |||||||||||||||||||||||||||
builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue); | |||||||||||||||||||||||||||
// if (b == T(0)) | |||||||||||||||||||||||||||
builder.createBlock(funcBody); | |||||||||||||||||||||||||||
auto bIsZero = | |||||||||||||||||||||||||||
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue); | |||||||||||||||||||||||||||
// return T(1) / T(0); | |||||||||||||||||||||||||||
thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>( | |||||||||||||||||||||||||||
builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult()); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(0)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(bIsZero->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (b == T(1)) | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto bIsOne = | |||||||||||||||||||||||||||
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue); | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(oneValue); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(1)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(bIsOne->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// if (b == T(-1)) { | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, | |||||||||||||||||||||||||||
bArg, minusOneValue); | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
builder.createBlock(funcBody); | |||||||||||||||||||||||||||
auto pIsOdd = builder.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue), | |||||||||||||||||||||||||||
zeroValue); | |||||||||||||||||||||||||||
// return T(-1); | |||||||||||||||||||||||||||
thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(minusOneValue); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p & T(1)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(pIsOdd->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// return T(1); | |||||||||||||||||||||||||||
// } // b == T(-1) | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(oneValue); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (b == T(-1)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(), | |||||||||||||||||||||||||||
fallthroughBlock); | |||||||||||||||||||||||||||
// return T(0); | |||||||||||||||||||||||||||
// } // (p < T(0)) | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(zeroValue); | |||||||||||||||||||||||||||
Block *loopHeader = builder.createBlock( | |||||||||||||||||||||||||||
funcBody, funcBody->end(), {elementType, elementType, elementType}, | |||||||||||||||||||||||||||
{builder.getLoc(), builder.getLoc(), builder.getLoc()}); | |||||||||||||||||||||||||||
// Set up conditional branch for (p < T(0)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(pIsNeg->getBlock()); | |||||||||||||||||||||||||||
// Set initial values of 'result', 'b' and 'p' for the loop. | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader, | |||||||||||||||||||||||||||
ValueRange{oneValue, bArg, pArg}); | |||||||||||||||||||||||||||
// T result = T(1); | |||||||||||||||||||||||||||
// while (true) { | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
// result *= b; | |||||||||||||||||||||||||||
// p >>= T(1); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
// return result; | |||||||||||||||||||||||||||
// b *= b; | |||||||||||||||||||||||||||
// } | |||||||||||||||||||||||||||
Value resultTmp = loopHeader->getArgument(0); | |||||||||||||||||||||||||||
Value baseTmp = loopHeader->getArgument(1); | |||||||||||||||||||||||||||
Value powerTmp = loopHeader->getArgument(2); | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(loopHeader); | |||||||||||||||||||||||||||
// if (p & T(1)) | |||||||||||||||||||||||||||
auto powerTmpIsOdd = builder.create<arith::CmpIOp>( | |||||||||||||||||||||||||||
arith::CmpIPredicate::ne, | |||||||||||||||||||||||||||
builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue); | |||||||||||||||||||||||||||
thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// result *= b; | |||||||||||||||||||||||||||
Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, | |||||||||||||||||||||||||||
builder.getLoc()); | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(thenBlock); | |||||||||||||||||||||||||||
builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); | |||||||||||||||||||||||||||
// Set up conditional branch for (p & T(1)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock, | |||||||||||||||||||||||||||
resultTmp); | |||||||||||||||||||||||||||
// Merged 'result'. | |||||||||||||||||||||||||||
newResultTmp = fallthroughBlock->getArgument(0); | |||||||||||||||||||||||||||
// p >>= T(1); | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue); | |||||||||||||||||||||||||||
// if (p == T(0)) | |||||||||||||||||||||||||||
auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, | |||||||||||||||||||||||||||
newPowerTmp, zeroValue); | |||||||||||||||||||||||||||
// return result; | |||||||||||||||||||||||||||
thenBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
builder.create<func::ReturnOp>(newResultTmp); | |||||||||||||||||||||||||||
fallthroughBlock = builder.createBlock(funcBody); | |||||||||||||||||||||||||||
// Set up conditional branch for (p == T(0)). | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); | |||||||||||||||||||||||||||
builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock); | |||||||||||||||||||||||||||
// b *= b; | |||||||||||||||||||||||||||
// } | |||||||||||||||||||||||||||
builder.setInsertionPointToEnd(fallthroughBlock); | |||||||||||||||||||||||||||
Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp); | |||||||||||||||||||||||||||
// Pass new values for 'result', 'b' and 'p' to the loop header. | |||||||||||||||||||||||||||
builder.create<cf::BranchOp>( | |||||||||||||||||||||||||||
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); | |||||||||||||||||||||||||||
return funcOp; | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
/// Convert IPowI into a call to a local function implementing | |||||||||||||||||||||||||||
/// the power operation. The local function computes a scalar result, | |||||||||||||||||||||||||||
/// so vector forms of IPowI are linearized. | |||||||||||||||||||||||||||
LogicalResult | |||||||||||||||||||||||||||
IPowIOpLowering::matchAndRewrite(math::IPowIOp op, | |||||||||||||||||||||||||||
PatternRewriter &rewriter) const { | |||||||||||||||||||||||||||
auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>(); | |||||||||||||||||||||||||||
if (!baseType) | |||||||||||||||||||||||||||
return rewriter.notifyMatchFailure(op, "non-integer base operand"); | |||||||||||||||||||||||||||
FunctionType funcType = | |||||||||||||||||||||||||||
FunctionType::get(rewriter.getContext(), {baseType, baseType}, baseType); | |||||||||||||||||||||||||||
// The outlined software implementation must have been already | |||||||||||||||||||||||||||
// generated. | |||||||||||||||||||||||||||
func::FuncOp elementFunc = getFuncOpCallback(funcType); | |||||||||||||||||||||||||||
if (!elementFunc) | |||||||||||||||||||||||||||
return rewriter.notifyMatchFailure(op, "missing software implementation"); | |||||||||||||||||||||||||||
Since this is a partial dialect conversion, can you not just add math::IPowI as an illegal op? Mogball: Since this is a partial dialect conversion, can you not just add `math::IPowI` as an illegal op? | |||||||||||||||||||||||||||
Will do. vzakhari: Will do. | |||||||||||||||||||||||||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands()); | |||||||||||||||||||||||||||
return success(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
namespace { | |||||||||||||||||||||||||||
struct ConvertMathToFuncsPass | |||||||||||||||||||||||||||
: public ConvertMathToFuncsBase<ConvertMathToFuncsPass> { | |||||||||||||||||||||||||||
ConvertMathToFuncsPass() = default; | |||||||||||||||||||||||||||
Nit: typically this comes before the pass declaration/definition Mogball: Nit: typically this comes before the pass declaration/definition | |||||||||||||||||||||||||||
void runOnOperation() override; | |||||||||||||||||||||||||||
private: | |||||||||||||||||||||||||||
// Generate outlined implementations for power operations | |||||||||||||||||||||||||||
These two checks aren't necessary since they are invariants enforced by the verifier. You just need to check that this is a scalar integer op (done by the first condition). Mogball: These two checks aren't necessary since they are invariants enforced by the verifier. You just… | |||||||||||||||||||||||||||
// and store them in powerFuncs map. | |||||||||||||||||||||||||||
void preprocessPowOperations(); | |||||||||||||||||||||||||||
// A map between function types deduced from power operations | |||||||||||||||||||||||||||
// and the corresponding outlined software implementations | |||||||||||||||||||||||||||
// of these operations. | |||||||||||||||||||||||||||
DenseMap<FunctionType, func::FuncOp> powerFuncs; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
} // namespace | |||||||||||||||||||||||||||
void ConvertMathToFuncsPass::preprocessPowOperations() { | |||||||||||||||||||||||||||
ModuleOp module = getOperation(); | |||||||||||||||||||||||||||
module.walk([&](Operation *op) { | |||||||||||||||||||||||||||
TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) { | |||||||||||||||||||||||||||
Type resultType = getElementTypeOrSelf(op.getResult().getType()); | |||||||||||||||||||||||||||
Type baseType = getElementTypeOrSelf(op.getOperands()[0].getType()); | |||||||||||||||||||||||||||
Type exponentType = getElementTypeOrSelf(op.getOperands()[1].getType()); | |||||||||||||||||||||||||||
FunctionType funcType = FunctionType::get( | |||||||||||||||||||||||||||
&getContext(), {baseType, exponentType}, resultType); | |||||||||||||||||||||||||||
// Generate the software implementation of this operation, | |||||||||||||||||||||||||||
// if it has not been generated yet. | |||||||||||||||||||||||||||
auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{}); | |||||||||||||||||||||||||||
if (entry.second) | |||||||||||||||||||||||||||
entry.first->second = createElementIPowIFunc(&module, funcType); | |||||||||||||||||||||||||||
}); | |||||||||||||||||||||||||||
}); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
void ConvertMathToFuncsPass::runOnOperation() { | |||||||||||||||||||||||||||
ModuleOp module = getOperation(); | |||||||||||||||||||||||||||
Not Done ReplyInline Actions
Can you remove the extra logic for now? I know you want to add support for other ops, but I think it'd be less automagical to just add another Case for FPowI Mogball: Can you remove the extra logic for now? I know you want to add support for other ops, but I… | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
// Create outlined implementations for power operations. | |||||||||||||||||||||||||||
preprocessPowOperations(); | |||||||||||||||||||||||||||
RewritePatternSet patterns(&getContext()); | |||||||||||||||||||||||||||
patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext()); | |||||||||||||||||||||||||||
// For the given FunctionType Returns FuncOp stored in powerFuncs map. | |||||||||||||||||||||||||||
auto getPowerFuncOpByType = [&](FunctionType type) -> func::FuncOp { | |||||||||||||||||||||||||||
Mogball: | |||||||||||||||||||||||||||
Since I am going to add FPowI support later, does it make sense to keep the switch now? vzakhari: Since I am going to add `FPowI` support later, does it make sense to keep the switch now? | |||||||||||||||||||||||||||
Yes it does. Thanks. Mogball: Yes it does. Thanks. | |||||||||||||||||||||||||||
auto it = powerFuncs.find(type); | |||||||||||||||||||||||||||
if (it == powerFuncs.end()) | |||||||||||||||||||||||||||
return {}; | |||||||||||||||||||||||||||
If all the types are the same, then you only need to hash based on the element type. Mogball: If all the types are the same, then you only need to hash based on the element type. | |||||||||||||||||||||||||||
This is true for IPowI but not for FPowI. I prefer to keep it as a function type so that the keys are consistent in the map, otherwise, for IPowI we will have IntegerType keys and for FPowI we will have FunctionType keys. vzakhari: This is true for `IPowI` but not for `FPowI`. I prefer to keep it as a function type so that… | |||||||||||||||||||||||||||
Then you can just use Type as the key. Mogball: Then you can just use `Type` as the key. | |||||||||||||||||||||||||||
In any case, you don't need std::map and a custom comparator. DenseMap<Type, T> works just fine. Mogball: In any case, you don't need `std::map` and a custom comparator. `DenseMap<Type, T>` works just… | |||||||||||||||||||||||||||
Should be done now. Thanks! vzakhari: Should be done now. Thanks! | |||||||||||||||||||||||||||
return it->second; | |||||||||||||||||||||||||||
}; | |||||||||||||||||||||||||||
patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType); | |||||||||||||||||||||||||||
ConversionTarget target(getContext()); | |||||||||||||||||||||||||||
target.addLegalDialect<arith::ArithmeticDialect, cf::ControlFlowDialect, | |||||||||||||||||||||||||||
func::FuncDialect, vector::VectorDialect>(); | |||||||||||||||||||||||||||
target.addIllegalOp<math::IPowIOp>(); | |||||||||||||||||||||||||||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) | |||||||||||||||||||||||||||
signalPassFailure(); | |||||||||||||||||||||||||||
} | |||||||||||||||||||||||||||
std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() { | |||||||||||||||||||||||||||
return std::make_unique<ConvertMathToFuncsPass>(); | |||||||||||||||||||||||||||
} |
Can you trim these includes? I don't think all of these are needed