diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -0,0 +1,75 @@
+//===- FunctionCallUtils.h - Utilities for C function calls -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares helper functions to call common simple C functions in
+// LLVMIR (e.g. to support printing and debugging, among other uses).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+#define MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+
+namespace LLVM {
+class LLVMFuncOp;
+
+/// Symbol names of commonly used external C functions.
+struct CFunctionNames {
+  static constexpr llvm::StringRef kPrintI64 = "printI64";
+  static constexpr llvm::StringRef kPrintU64 = "printU64";
+  static constexpr llvm::StringRef kPrintF32 = "printF32";
+  static constexpr llvm::StringRef kPrintF64 = "printF64";
+  static constexpr llvm::StringRef kPrintOpen = "printOpen";
+  static constexpr llvm::StringRef kPrintClose = "printClose";
+  static constexpr llvm::StringRef kPrintComma = "printComma";
+  static constexpr llvm::StringRef kPrintNewline = "printNewline";
+  static constexpr llvm::StringRef kMalloc = "malloc";
+  static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
+};
+
+/// Helper functions to look up or create the declaration for commonly used
+/// external C function calls. Such declarations can then be invoked by creating
+/// a CallOp with the proper arguments via `createLLVMCall`.
+/// The list of functions provided here must be implemented separately (e.g. as
+/// part of a support runtime library or as part of the libc).
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
+
+/// Returns the LLVMFuncOp named `name` in `moduleOp` if there is one, and
+/// otherwise creates a declaration of type `resultType(paramTypes)`.
+LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                  ArrayRef<Type> paramTypes = {},
+                                  Type resultType = {});
+
+/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
+/// Returns the results of the created call operation.
+Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
+                                       LLVM::LLVMFuncOp fn,
+                                       ValueRange args = {},
+                                       ArrayRef<Type> resultTypes = {});
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -1793,31 +1794,6 @@
     return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
   }
 
-  // Creates a call to an allocation function with params and casts the
-  // resulting void pointer to ptrType.
-  Value createAllocCall(Location loc, StringRef name, Type ptrType,
-                        ArrayRef<Value> params, ModuleOp module,
-                        ConversionPatternRewriter &rewriter) const {
-    SmallVector<Type, 2> paramTypes;
-    auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
-    if (!allocFuncOp) {
-      for (Value param : params)
-        paramTypes.push_back(param.getType());
-      auto allocFuncType =
-          LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(module.getBody());
-      allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
-                                                      name, allocFuncType);
-    }
-    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
-    auto allocatedPtr = rewriter
-                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
-                                                  allocFuncSymbol, params)
-                            .getResult(0);
-    return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
-  }
-
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
virtual std::tuple @@ -1909,9 +1885,13 @@ // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); + auto allocFuncOp = LLVM::lookupOrCreateFn( + allocOp->getParentOfType(), LLVM::CFunctionNames::kMalloc, + {getIndexType()}, getVoidPtrType()); + auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, + getVoidPtrType()); Value allocatedPtr = - createAllocCall(loc, "malloc", elementPtrType, {sizeBytes}, - allocOp->getParentOfType(), rewriter); + rewriter.create(loc, elementPtrType, results[0]); Value alignedPtr = allocatedPtr; if (alignment) { @@ -1991,9 +1971,15 @@ sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - Value allocatedPtr = createAllocCall( - loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes}, - allocOp->getParentOfType(), rewriter); + auto allocFuncOp = LLVM::lookupOrCreateFn( + allocOp->getParentOfType(), + LLVM::CFunctionNames::kAlignedAlloc, {getIndexType(), getIndexType()}, + getVoidPtrType()); + auto results = + createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes}, + getVoidPtrType()); + Value allocatedPtr = + rewriter.create(loc, elementPtrType, results[0]); return std::make_tuple(allocatedPtr, allocatedPtr); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -1311,11 +1312,14 @@ Type eltType = 
vectorType ? vectorType.getElementType() : printType; Operation *printer; if (eltType.isF32()) { - printer = getPrintFloat(printOp); + printer = + LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType()); } else if (eltType.isF64()) { - printer = getPrintDouble(printOp); + printer = + LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType()); } else if (eltType.isIndex()) { - printer = getPrintU64(printOp); + printer = + LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType()); } else if (auto intTy = eltType.dyn_cast()) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1325,7 +1329,8 @@ if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = getPrintU64(printOp); + printer = LLVM::lookupOrCreatePrintU64Fn( + printOp->getParentOfType()); } else { return failure(); } @@ -1338,7 +1343,8 @@ conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = getPrintI64(printOp); + printer = LLVM::lookupOrCreatePrintI64Fn( + printOp->getParentOfType()); } else { return failure(); } @@ -1351,7 +1357,9 @@ int64_t rank = vectorType ? 
vectorType.getRank() : 0; emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, conversion); - emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); + emitCall(rewriter, printOp->getLoc(), + LLVM::lookupOrCreatePrintNewlineFn( + printOp->getParentOfType())); rewriter.eraseOp(printOp); return success(); } @@ -1386,8 +1394,10 @@ return; } - emitCall(rewriter, loc, getPrintOpen(op)); - Operation *printComma = getPrintComma(op); + emitCall(rewriter, loc, + LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType())); + Operation *printComma = + LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType()); int64_t dim = vectorType.getDimSize(0); for (int64_t d = 0; d < dim; ++d) { auto reducedType = @@ -1401,7 +1411,8 @@ if (d != dim - 1) emitCall(rewriter, loc, printComma); } - emitCall(rewriter, loc, getPrintClose(op)); + emitCall(rewriter, loc, + LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType())); } // Helper to emit a call. @@ -1410,46 +1421,6 @@ rewriter.create(loc, TypeRange(), rewriter.getSymbolRefAttr(ref), params); } - - // Helper for printer method declaration (first hit) and lookup. - static Operation *getPrint(Operation *op, StringRef name, - ArrayRef params) { - auto module = op->getParentOfType(); - auto func = module.lookupSymbol(name); - if (func) - return func; - OpBuilder moduleBuilder(module.getBodyRegion()); - return moduleBuilder.create( - op->getLoc(), name, - LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()), - params)); - } - - // Helpers for method names. 
-  Operation *getPrintI64(Operation *op) const {
-    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintU64(Operation *op) const {
-    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintFloat(Operation *op) const {
-    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
-  }
-  Operation *getPrintDouble(Operation *op) const {
-    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
-  }
-  Operation *getPrintOpen(Operation *op) const {
-    return getPrint(op, "printOpen", {});
-  }
-  Operation *getPrintClose(Operation *op) const {
-    return getPrint(op, "printClose", {});
-  }
-  Operation *getPrintComma(Operation *op) const {
-    return getPrint(op, "printComma", {});
-  }
-  Operation *getPrintNewline(Operation *op) const {
-    return getPrint(op, "printNewline", {});
-  }
 };
 
 /// Progressive lowering of ExtractStridedSliceOp to either:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -0,0 +1,104 @@
+//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helper functions to call common simple C functions in
+// LLVMIR (e.g. to support printing and debugging, among other uses).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+// Out-of-line definitions for the in-class `static constexpr` members. Before
+// C++17 such members are not implicitly inline, so passing them by value (an
+// ODR-use) without a namespace-scope definition can fail to link.
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintI64;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintU64;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintF32;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintF64;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintOpen;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintClose;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintComma;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kPrintNewline;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kMalloc;
+constexpr llvm::StringRef mlir::LLVM::CFunctionNames::kAlignedAlloc;
+
+/// Returns the LLVMFuncOp named `name` in `moduleOp` if there is one.
+/// Otherwise inserts a declaration of type `resultType(paramTypes)` at the
+/// start of the module body and returns it. A null `resultType` (the default
+/// declared in the header) produces a function returning void.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+  if (func)
+    return func;
+  // LLVMFunctionType::get requires a non-null result type.
+  if (!resultType)
+    resultType = LLVM::LLVMVoidType::get(moduleOp->getContext());
+  OpBuilder b(moduleOp.getBodyRegion());
+  return b.create<LLVM::LLVMFuncOp>(
+      moduleOp->getLoc(), name,
+      LLVM::LLVMFunctionType::get(resultType, paramTypes));
+}
+
+// The print entry points below all return void; they differ only in their
+// (optional) scalar argument type.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintI64,
+                          IntegerType::get(moduleOp->getContext(), 64),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintU64,
+                          IntegerType::get(moduleOp->getContext(), 64),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintF32,
+                          Float32Type::get(moduleOp->getContext()),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintF64,
+                          Float64Type::get(moduleOp->getContext()),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintOpen, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintClose, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintComma, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, CFunctionNames::kPrintNewline, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+/// Creates an LLVM::CallOp to `fn` with operands `args` and result types
+/// `resultTypes` at the insertion point of `b`, and returns its results.
+Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
+                                                   LLVM::LLVMFuncOp fn,
+                                                   ValueRange args,
+                                                   ArrayRef<Type> resultTypes) {
+  return b
+      .create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn), args)
+      ->getResults();
+}