diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -0,0 +1,63 @@
+//===- FunctionCallUtils.h - Utilities for C function calls -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+#define MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+
+namespace LLVM {
+class LLVMFuncOp;
+
+/// Helper functions to lookup or create the declaration for commonly used
+/// external C function calls. Such ops can then be invoked by creating a CallOp
+/// with the proper arguments via `createLLVMCall`.
+/// The list of functions provided here must be implemented separately (e.g. as
+/// part of a support runtime library or as part of the libc).
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+                                              Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+
+/// Lookup or create a function `name` with signature `resultType(paramTypes)`.
+LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                  ArrayRef<Type> paramTypes = {},
+                                  Type resultType = {});
+
+/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
+Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
+                                       LLVM::LLVMFuncOp fn,
+                                       ValueRange args = {},
+                                       ArrayRef<Type> resultTypes = {});
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -1793,31 +1794,6 @@
     return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
   }
 
-  // Creates a call to an allocation function with params and casts the
-  // resulting void pointer to ptrType.
-  Value createAllocCall(Location loc, StringRef name, Type ptrType,
-                        ArrayRef<Value> params, ModuleOp module,
-                        ConversionPatternRewriter &rewriter) const {
-    SmallVector<Type, 2> paramTypes;
-    auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
-    if (!allocFuncOp) {
-      for (Value param : params)
-        paramTypes.push_back(param.getType());
-      auto allocFuncType =
-          LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(module.getBody());
-      allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
-                                                      name, allocFuncType);
-    }
-    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
-    auto allocatedPtr = rewriter
-                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
-                                                  allocFuncSymbol, params)
-                            .getResult(0);
-    return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
-  }
-
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
virtual std::tuple @@ -1909,9 +1885,12 @@ // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); + auto allocFuncOp = LLVM::lookupOrCreateMallocFn( + allocOp->getParentOfType(), getIndexType()); + auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, + getVoidPtrType()); Value allocatedPtr = - createAllocCall(loc, "malloc", elementPtrType, {sizeBytes}, - allocOp->getParentOfType(), rewriter); + rewriter.create(loc, elementPtrType, results[0]); Value alignedPtr = allocatedPtr; if (alignment) { @@ -1991,9 +1970,13 @@ sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - Value allocatedPtr = createAllocCall( - loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes}, - allocOp->getParentOfType(), rewriter); + auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( + allocOp->getParentOfType(), getIndexType()); + auto results = + createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes}, + getVoidPtrType()); + Value allocatedPtr = + rewriter.create(loc, elementPtrType, results[0]); return std::make_tuple(allocatedPtr, allocatedPtr); } @@ -2056,31 +2039,17 @@ // Get frequently used types. MLIRContext *context = builder.getContext(); - auto voidType = LLVM::LLVMVoidType::get(context); Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); auto i1Type = IntegerType::get(context, 1); Type indexType = typeConverter.getIndexType(); // Find the malloc and free, or declare them if necessary. 
auto module = builder.getInsertionPoint()->getParentOfType(); - auto mallocFunc = module.lookupSymbol("malloc"); - if (!mallocFunc && toDynamic) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(module.getBody()); - mallocFunc = builder.create( - builder.getUnknownLoc(), "malloc", - LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType), - /*isVarArg=*/false)); - } - auto freeFunc = module.lookupSymbol("free"); - if (!freeFunc && !toDynamic) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(module.getBody()); - freeFunc = builder.create( - builder.getUnknownLoc(), "free", - LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType), - /*isVarArg=*/false)); - } + LLVM::LLVMFuncOp freeFunc, mallocFunc; + if (toDynamic) + mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); + if (!toDynamic) + freeFunc = LLVM::lookupOrCreateFreeFn(module); // Initialize shared constants. Value zero = @@ -2217,17 +2186,7 @@ DeallocOp::Adaptor transformed(operands); // Insert the `free` declaration if it is not already present. 
-    auto freeFunc =
-        op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
-    if (!freeFunc) {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(
-          op->getParentOfType<ModuleOp>().getBody());
-      freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
-          rewriter.getUnknownLoc(), "free",
-          LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType()));
-    }
-
+    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
     MemRefDescriptor memref(transformed.memref());
     Value casted = rewriter.create<LLVM::BitcastOp>(
         op.getLoc(), getVoidPtrType(),
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
@@ -1311,11 +1312,14 @@
     Type eltType = vectorType ? vectorType.getElementType() : printType;
     Operation *printer;
     if (eltType.isF32()) {
-      printer = getPrintFloat(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
     } else if (eltType.isF64()) {
-      printer = getPrintDouble(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
     } else if (eltType.isIndex()) {
-      printer = getPrintU64(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
       // Integers need a zero or sign extension on the operand
       // (depending on the source type) as well as a signed or
@@ -1325,7 +1329,8 @@
         if (width <= 64) {
           if (width < 64)
             conversion = PrintConversion::ZeroExt64;
-          printer = getPrintU64(printOp);
+          printer = LLVM::lookupOrCreatePrintU64Fn(
+              printOp->getParentOfType<ModuleOp>());
         } else {
           return failure();
         }
@@ -1338,7 +1343,8 @@
             conversion = PrintConversion::ZeroExt64;
           else if (width < 64)
             conversion = PrintConversion::SignExt64;
-          printer = getPrintI64(printOp);
+          printer = LLVM::lookupOrCreatePrintI64Fn(
+              printOp->getParentOfType<ModuleOp>());
         } else {
           return failure();
         }
@@ -1351,7 +1357,9 @@
     int64_t rank = vectorType ? vectorType.getRank() : 0;
     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
               conversion);
-    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
+    emitCall(rewriter, printOp->getLoc(),
+             LLVM::lookupOrCreatePrintNewlineFn(
+                 printOp->getParentOfType<ModuleOp>()));
     rewriter.eraseOp(printOp);
     return success();
   }
@@ -1386,8 +1394,10 @@
       return;
     }
 
-    emitCall(rewriter, loc, getPrintOpen(op));
-    Operation *printComma = getPrintComma(op);
+    emitCall(rewriter, loc,
+             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
+    Operation *printComma =
+        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
     int64_t dim = vectorType.getDimSize(0);
     for (int64_t d = 0; d < dim; ++d) {
       auto reducedType =
@@ -1401,7 +1411,8 @@
       if (d != dim - 1)
         emitCall(rewriter, loc, printComma);
     }
-    emitCall(rewriter, loc, getPrintClose(op));
+    emitCall(rewriter, loc,
+             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
   }
 
   // Helper to emit a call.
@@ -1410,46 +1421,6 @@
     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                   rewriter.getSymbolRefAttr(ref), params);
   }
-
-  // Helper for printer method declaration (first hit) and lookup.
-  static Operation *getPrint(Operation *op, StringRef name,
-                             ArrayRef<Type> params) {
-    auto module = op->getParentOfType<ModuleOp>();
-    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
-    if (func)
-      return func;
-    OpBuilder moduleBuilder(module.getBodyRegion());
-    return moduleBuilder.create<LLVM::LLVMFuncOp>(
-        op->getLoc(), name,
-        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
-                                    params));
-  }
-
-  // Helpers for method names.
-  Operation *getPrintI64(Operation *op) const {
-    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintU64(Operation *op) const {
-    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintFloat(Operation *op) const {
-    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
-  }
-  Operation *getPrintDouble(Operation *op) const {
-    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
-  }
-  Operation *getPrintOpen(Operation *op) const {
-    return getPrint(op, "printOpen", {});
-  }
-  Operation *getPrintClose(Operation *op) const {
-    return getPrint(op, "printClose", {});
-  }
-  Operation *getPrintComma(Operation *op) const {
-    return getPrint(op, "printComma", {});
-  }
-  Operation *getPrintNewline(Operation *op) const {
-    return getPrint(op, "printNewline", {});
-  }
 };
 
 /// Progressive lowering of ExtractStridedSliceOp to either:
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Transforms)
 
 add_mlir_dialect_library(MLIRLLVMIR
+  IR/FunctionCallUtils.cpp
   IR/LLVMDialect.cpp
   IR/LLVMTypes.cpp
   IR/LLVMTypeSyntax.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -0,0 +1,129 @@
+//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+/// Symbol names of the external C functions declared by the helpers below.
+/// These functions are not defined here: they must be implemented separately
+/// (e.g. as part of a support runtime library or as part of the libc).
+static constexpr llvm::StringRef kPrintI64 = "printI64";
+static constexpr llvm::StringRef kPrintU64 = "printU64";
+static constexpr llvm::StringRef kPrintF32 = "printF32";
+static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintOpen = "printOpen";
+static constexpr llvm::StringRef kPrintClose = "printClose";
+static constexpr llvm::StringRef kPrintComma = "printComma";
+static constexpr llvm::StringRef kPrintNewline = "printNewline";
+static constexpr llvm::StringRef kMalloc = "malloc";
+static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
+static constexpr llvm::StringRef kFree = "free";
+
+/// Lookup the LLVMFuncOp named `name` in `moduleOp`; if there is none, insert
+/// a declaration with signature `resultType(paramTypes)` into the module body.
+/// A null `resultType` (the default) declares a function returning void.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  if (auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return func;
+  // Map the null default to void instead of handing a null result type to
+  // LLVMFunctionType::get.
+  if (!resultType)
+    resultType = LLVM::LLVMVoidType::get(moduleOp->getContext());
+  OpBuilder b(moduleOp.getBodyRegion());
+  return b.create<LLVM::LLVMFuncOp>(
+      moduleOp->getLoc(), name,
+      LLVM::LLVMFunctionType::get(resultType, paramTypes));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintI64,
+                          IntegerType::get(moduleOp->getContext(), 64),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintU64,
+                          IntegerType::get(moduleOp->getContext(), 64),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintF32,
+                          Float32Type::get(moduleOp->getContext()),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintF64,
+                          Float64Type::get(moduleOp->getContext()),
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintOpen, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintClose, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintComma, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintNewline, {},
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+                                                    Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kMalloc, indexType,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+                                                          Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kAlignedAlloc, {indexType, indexType},
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kFree,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+/// Create an LLVM::CallOp to `fn` with operands `args` and return its results.
+Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
+                                                   LLVM::LLVMFuncOp fn,
+                                                   ValueRange args,
+                                                   ArrayRef<Type> resultTypes) {
+  return b
+      .create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn), args)
+      ->getResults();
+}