diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -77,10 +77,17 @@ OpBuilder &builder); protected: + /// Convert a function argument type to an LLVM type using 'convertType'. + /// MemRef arguments are promoted to a pointer to the converted type. + virtual LLVM::LLVMType convertArgType(Type type); + /// LLVM IR module used to parse/create types. llvm::Module *module; LLVM::LLVMDialect *llvmDialect; + // Extract an LLVM IR dialect type. + LLVM::LLVMType unwrap(Type type); + private: Type convertStandardType(Type type); @@ -120,9 +127,23 @@ // Get the LLVM representation of the index type based on the bitwidth of the // pointer as defined by the data layout of the module. LLVM::LLVMType getIndexType(); +}; - // Extract an LLVM IR dialect type. - LLVM::LLVMType unwrap(Type type); +/// Custom LLVMTypeConverter that overrides `convertArgType` to replace the +/// type of MemRef function arguments with a bare pointer to the MemRef +/// element type. +class BarePtrTypeConverter : public mlir::LLVMTypeConverter { +public: + using LLVMTypeConverter::LLVMTypeConverter; + +private: + /// Convert a function argument type to an LLVM type using 'convertType' + /// except for MemRef arguments. MemRef type is converted to a bare LLVM + /// pointer to the MemRef element type.
+ LLVM::LLVMType convertArgType(Type type) override; + + /// Convert a MemRef type to a bare LLVM pointer to its element type. + mlir::Type convertMemRefTypeToBarePtr(mlir::MemRefType type); }; /// Helper class to produce LLVM dialect operations extracting or inserting diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -44,8 +44,8 @@ std::function(MLIRContext *)>; /// Collect a set of patterns to convert memory-related operations from the -/// Standard dialect to the LLVM dialect, excluding the memory-related -/// operations. +/// Standard dialect to the LLVM dialect, excluding non-memory-related +/// operations and FuncOp. void populateStdToLLVMMemoryConversionPatters( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); @@ -54,10 +54,26 @@ void populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Collect a set of patterns to convert from the Standard dialect to LLVM. +/// Collect the default pattern to convert a FuncOp to the LLVM dialect. +void populateStdToLLVMDefaultFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Collect a set of default patterns to convert from the Standard dialect to +/// LLVM. void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); +/// Collect the pattern to convert a FuncOp to the LLVM dialect using the bare +/// pointer calling convention for MemRef function arguments. +void populateStdToLLVMBarePtrFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Collect a set of patterns to convert from the Standard dialect to +/// LLVM using the bare pointer calling convention for MemRef function +/// arguments.
+void populateStdToLLVMBarePtrConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// By default stdlib malloc/free are used for allocating MemRef payloads. /// Specifying `useAlloca-true` emits stack allocations instead. In the future diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -321,7 +321,7 @@ TypeConverter::SignatureConversion &conversion); - /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument from, Value to); + /// Replace uses of `from` with `to`, skipping uses in `to`'s defining op. + void replaceUsesOfWith(Value from, Value to); /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -667,7 +667,7 @@ BlockArgument arg = block.getArgument(en.index()); Value loaded = rewriter.create(loc, arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); + rewriter.replaceUsesOfWith(arg, loaded); } } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -44,6 +44,12 @@ llvm::cl::desc("Replace emission of malloc/free by alloca"), llvm::cl::init(false)); +static llvm::cl::opt clUseBarePtrCallConv( + PASS_NAME "-use-bare-ptr-memref-call-conv", + llvm::cl::desc("Replace FuncOp's MemRef arguments with " + "bare pointers to the MemRef element types"), + llvm::cl::init(false)); + 
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : llvmDialect(ctx->getRegisteredDialect()) { assert(llvmDialect && "LLVM IR dialect is not registered"); @@ -107,6 +113,17 @@ return converted.getPointerTo(); } +// Convert a function argument type to an LLVM type using 'convertType'. MemRef +// arguments are promoted to a pointer to the converted type. +LLVM::LLVMType LLVMTypeConverter::convertArgType(Type type) { + auto converted = convertType(type).dyn_cast_or_null(); + if (!converted) + return {}; + if (type.isa() || type.isa()) + converted = converted.getPointerTo(); + return converted; +} + // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, @@ -117,11 +134,9 @@ // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(type.getInputs())) { Type type = en.value(); - auto converted = convertType(type).dyn_cast_or_null(); + LLVM::LLVMType converted = convertArgType(type); if (!converted) return {}; - if (type.isa() || type.isa()) - converted = converted.getPointerTo(); result.addInputs(en.index(), converted); } @@ -239,6 +254,33 @@ .Default([](Type) { return Type(); }); } +// Convert a function argument type to an LLVM type using 'convertType' except +// for MemRef arguments. MemRef type is converted to a bare LLVM pointer to the +// MemRef element type. +LLVM::LLVMType BarePtrTypeConverter::convertArgType(Type type) { + // TODO: Add support for unranked memref. + if (auto memrefTy = type.dyn_cast()) + return convertMemRefTypeToBarePtr(memrefTy) + .dyn_cast_or_null(); + return convertType(type).dyn_cast_or_null(); +} + +// Converts MemRefType to a bare LLVM pointer to the MemRef element type.
+Type BarePtrTypeConverter::convertMemRefTypeToBarePtr(MemRefType type) { + int64_t offset; + SmallVector strides; + bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); + assert(strideSuccess && + "Non-strided layout maps must have been normalized away"); + (void)strideSuccess; + + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); + return ptrTy; +} + LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &lowering_, PatternBenefit benefit) @@ -494,27 +536,29 @@ LLVM::LLVMDialect &dialect; }; -struct FuncOpConversion : public LLVMLegalizationPattern { - using LLVMLegalizationPattern::LLVMLegalizationPattern; - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); - FunctionType type = funcOp.getType(); - - // Store the positions of memref-typed arguments so that we can emit loads - // from them to follow the calling convention. - SmallVector promotedArgIndices; - promotedArgIndices.reserve(type.getNumInputs()); +struct FuncOpConversionBase : public LLVMLegalizationPattern { +protected: + using LLVMLegalizationPattern::LLVMLegalizationPattern; + using UnsignedTypePair = std::pair; + + // Gather the positions and types of memref-typed arguments in a given + // FunctionType. + void getMemRefArgIndicesAndTypes( + FunctionType type, SmallVectorImpl &argsInfo) const { + argsInfo.reserve(type.getNumInputs()); for (auto en : llvm::enumerate(type.getInputs())) { if (en.value().isa() || en.value().isa()) - promotedArgIndices.push_back(en.index()); + argsInfo.push_back({en.index(), en.value()}); } + } - // Convert the original function arguments. Struct arguments are promoted to - // pointer to struct arguments to allow calling external functions with - // various ABIs (e.g.
compiled from C/C++ on platform X). + // Convert input FuncOp to a new FuncOp in LLVM dialect by using the + // LLVMTypeConverter provided to this legalization pattern. + LLVM::LLVMFuncOp + convertFuncOpToLLVMFuncOp(FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Convert the original function arguments. They are converted using the + // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = lowering.convertFunctionSignature( @@ -533,22 +577,92 @@ // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( - op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, + funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + return newFuncOp; + } +}; + +// FuncOp legalization pattern that converts MemRef arguments to pointers to +// MemRef descriptors (LLVM struct data types) containing all the MemRef type +// information. +struct FuncOpConversion : public FuncOpConversionBase { + using FuncOpConversionBase::FuncOpConversionBase; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + + // Store the positions of memref-typed arguments so that we can emit loads + // from them to follow the calling convention. + SmallVector promotedArgsInfo; + getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + // Insert loads from memref descriptor pointers in function bodies.
if (!newFuncOp.getBody().empty()) { Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); - for (unsigned idx : promotedArgIndices) { - BlockArgument arg = firstBlock->getArgument(idx); + for (const auto &argInfo : promotedArgsInfo) { + BlockArgument arg = firstBlock->getArgument(argInfo.first); Value loaded = rewriter.create(funcOp.getLoc(), arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); + rewriter.replaceUsesOfWith(arg, loaded); + } + } + + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +// FuncOp legalization pattern that converts MemRef arguments to bare pointers +// to the MemRef element type. This will impact the calling convention and ABI. +struct BarePtrFuncOpConversion : public FuncOpConversionBase { + using FuncOpConversionBase::FuncOpConversionBase; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + + // Store the positions and type of memref-typed arguments so that we can + // promote them to MemRef descriptor structs at the beginning of the + // function. + SmallVector promotedArgsInfo; + getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + + // Promote bare pointers from MemRef arguments to a MemRef descriptor struct + // at the beginning of the function so that all the MemRefs in the function + // have a uniform representation. + if (!newFuncOp.getBody().empty()) { + Block *firstBlock = &newFuncOp.getBody().front(); + rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); + auto funcLoc = funcOp.getLoc(); + for (const auto &argInfo : promotedArgsInfo) { + // TODO: Add support for unranked MemRefs.
+ if (auto memrefType = argInfo.second.dyn_cast()) { + // Replace argument with a placeholder (undef), promote argument to a + // MemRef descriptor and replace placeholder with the last instruction + // of the MemRef descriptor. The placeholder is needed to avoid + // replacing argument uses in the MemRef descriptor instructions. + BlockArgument arg = firstBlock->getArgument(argInfo.first); + Value placeHolder = + rewriter.create(funcLoc, arg.getType()); + rewriter.replaceUsesOfWith(arg, placeHolder); + auto desc = MemRefDescriptor::fromStaticShape( + rewriter, funcLoc, lowering, memrefType, arg); + // Replace/erase the placeholder via the rewriter that created it. + rewriter.replaceOp(placeHolder.getDefiningOp(), {desc}); + } } } @@ -2126,7 +2240,6 @@ // clang-format off patterns.insert< DimOpLowering, - FuncOpConversion, LoadOpLowering, MemRefCastOpLowering, StoreOpLowering, @@ -2139,8 +2252,26 @@ // clang-format on } +void mlir::populateStdToLLVMDefaultFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(*converter.getDialect(), converter); +} + void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns); + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns); +} + +void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(*converter.getDialect(), converter); +} + +void mlir::populateStdToLLVMBarePtrConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatters(converter, patterns); } @@ -2210,6 +2341,12 @@ return std::make_unique(context); } +/// 
Create a BarePtrTypeConverter for `context` (bare-pointer call convention). +static std::unique_ptr +makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) { + return std::make_unique(context); +} + namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ModulePass { @@ -2274,6 +2411,9 @@ "Standard to the LLVM dialect", [] { return std::make_unique( - clUseAlloca.getValue(), populateStdToLLVMConversionPatterns, - makeStandardToLLVMTypeConverter); + clUseAlloca.getValue(), + clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns + : populateStdToLLVMConversionPatterns, + clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter + : makeStandardToLLVMTypeConverter); }); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -861,8 +861,7 @@ return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { +void ConversionPatternRewriter::replaceUsesOfWith(Value from, Value to) { for (auto &u : from.getUses()) { if (u.getOwner() == to.getDefiningOp()) continue; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR + +// BAREPTR-LABEL: func @check_noalias +// BAREPTR-SAME: [[ARG:%.*]]: !llvm<"float*"> {llvm.noalias = true} +func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) { + return +} + +// TODO: Move static-shape tests from convert-memref-ops.mlir here.