diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -81,30 +81,37 @@
 LogicalResult
 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
                                         PatternRewriter &rewriter) const {
-  auto module = op->template getParentOfType<ModuleOp>();
+  // Look up / declare the libm function in the closest enclosing symbol
+  // table instead of requiring a ModuleOp parent.
+  auto module = SymbolTable::getNearestSymbolTable(op);
+  // With no enclosing symbol table there is nowhere to look up or declare the
+  // callee; leave the op unconverted rather than dereferencing null below.
+  if (!module)
+    return failure();
   auto type = op.getType();
   // TODO: Support Float16 by upcasting to Float32
   if (!type.template isa<Float32Type, Float64Type>())
     return failure();
 
   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
-  auto opFunc = module.template lookupSymbol<FuncOp>(name);
+  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+      SymbolTable::lookupSymbolIn(module, name));
   // Forward declare function if it hasn't already been
   if (!opFunc) {
     OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.setInsertionPointToStart(&module->getRegion(0).front());
     auto opFunctionTy = FunctionType::get(
         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
     opFunc = rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name,
                                      opFunctionTy);
     opFunc.setPrivate();
   }
-  assert(opFunc.getType().template cast<FunctionType>().getResults() ==
-         op->getResultTypes());
-  assert(opFunc.getType().template cast<FunctionType>().getInputs() ==
-         op->getOperandTypes());
+  // The symbol found (or just created) must be function-like to be callable.
+  assert(opFunc->template hasTrait<mlir::OpTrait::FunctionLike>());
 
-  rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
+  // Call by symbol name: opFunc is only known to be a SymbolOpInterface here.
+  rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(),
+                                      op->getOperands());
 
   return success();
 }