diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -47,8 +47,9 @@
   /// Convert a function type. The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
-                                          SignatureConversion &result);
+  virtual LLVM::LLVMType convertFunctionSignature(FunctionType type,
+                                                  bool isVariadic,
+                                                  SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
   /// supported LLVM IR type. In particular, if more than one values is
@@ -81,6 +82,9 @@
   llvm::Module *module;
   LLVM::LLVMDialect *llvmDialect;
 
+  /// Extract an LLVM IR dialect type.
+  LLVM::LLVMType unwrap(Type type);
+
 private:
   Type convertStandardType(Type type);
 
@@ -120,9 +124,24 @@
   // Get the LLVM representation of the index type based on the bitwidth of the
   // pointer as defined by the data layout of the module.
   LLVM::LLVMType getIndexType();
+};
 
-  // Extract an LLVM IR dialect type.
-  LLVM::LLVMType unwrap(Type type);
+/// Custom LLVMTypeConverter that overrides `convertFunctionSignature` to
+/// replace the type of MemRef function arguments with a bare pointer to the
+/// MemRef element type.
+class BarePtrTypeConverter : public LLVMTypeConverter {
+public:
+  using LLVMTypeConverter::LLVMTypeConverter;
+
+  /// Converts function signature following LLVMTypeConverter approach but
+  /// replacing the type of MemRef arguments with a bare LLVM pointer to
+  /// the MemRef element type.
+  LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
+                                          SignatureConversion &result) override;
+
+private:
+  /// Returns a bare pointer to the element type of `type`, or null on failure.
+  Type convertMemRefTypeToBarePtr(MemRefType type);
 };
 
 /// Helper class to produce LLVM dialect operations extracting or inserting
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -44,8 +44,8 @@
     std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
 
 /// Collect a set of patterns to convert memory-related operations from the
-/// Standard dialect to the LLVM dialect, excluding the memory-related
-/// operations.
+/// Standard dialect to the LLVM dialect, excluding non-memory-related
+/// operations and FuncOp.
 void populateStdToLLVMMemoryConversionPatters(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
@@ -54,10 +54,26 @@
 void populateStdToLLVMNonMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
-/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
+void populateStdToLLVMDefaultFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of default patterns to convert from the Standard dialect to
+/// LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
 
+/// Collect the pattern to convert a FuncOp to the LLVM dialect using the bare
+/// pointer calling convention for MemRef function arguments.
+void populateStdToLLVMBarePtrFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of patterns to convert from the Standard dialect to LLVM
+/// using the bare pointer calling convention for MemRef function arguments.
+/// `converter` should be a BarePtrTypeConverter.
+void populateStdToLLVMBarePtrConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// By default stdlib malloc/free are used for allocating MemRef payloads.
 /// Specifying `useAlloca-true` emits stack allocations instead. In the future
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -321,7 +321,8 @@
                            TypeConverter::SignatureConversion &conversion);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
+  /// Replace all the uses of value `from` with value `to`, except for uses
+  /// by the operation that defines `to`.
+  void replaceUsesOfWith(Value from, Value to);
 
   /// Return the converted value that replaces 'key'. Return 'key' if there is
   /// no such a converted value.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -667,7 +667,7 @@
 
         BlockArgument arg = block.getArgument(en.index());
         Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
+        rewriter.replaceUsesOfWith(arg, loaded);
       }
     }
 
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -44,6 +44,12 @@
                 llvm::cl::desc("Replace emission of malloc/free by alloca"),
                 llvm::cl::init(false));
 
+static llvm::cl::opt<bool> clUseBarePtrCallConv(
+    PASS_NAME "-use-bare-ptr-memref-call-conv",
+    llvm::cl::desc("Replace FuncOp's MemRef arguments with "
+                   "bare pointers to the MemRef element types"),
+    llvm::cl::init(false));
+
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
@@ -239,6 +245,60 @@
       .Default([](Type) { return Type(); });
 }
 
+// Converts function signature following LLVMTypeConverter approach but
+// replacing the type of MemRef arguments with a bare LLVM pointer to
+// the MemRef element type.
+LLVM::LLVMType BarePtrTypeConverter::convertFunctionSignature(
+    FunctionType type, bool isVariadic,
+    LLVMTypeConverter::SignatureConversion &result) {
+  // Convert argument types one by one and check for errors.
+  for (auto en : llvm::enumerate(type.getInputs())) {
+    Type argType = en.value();
+    Type converted;
+    if (auto memrefTy = argType.dyn_cast<MemRefType>())
+      converted = convertMemRefTypeToBarePtr(memrefTy)
+                      .dyn_cast_or_null<LLVM::LLVMType>();
+    else
+      converted = convertType(argType).dyn_cast_or_null<LLVM::LLVMType>();
+
+    if (!converted)
+      return {};
+    result.addInputs(en.index(), converted);
+  }
+
+  SmallVector<LLVM::LLVMType, 8> argTypes;
+  argTypes.reserve(llvm::size(result.getConvertedTypes()));
+  for (Type convertedType : result.getConvertedTypes())
+    argTypes.push_back(unwrap(convertedType));
+
+  // If the function does not return anything, use the void result type; if it
+  // returns one element, convert it; otherwise pack the result types into a
+  // struct.
+  LLVM::LLVMType resultType =
+      type.getNumResults() == 0
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
+}
+
+// Converts MemRefType to a bare LLVM pointer to the MemRef element type.
+Type BarePtrTypeConverter::convertMemRefTypeToBarePtr(MemRefType type) {
+  // BarePtrFuncOpConversion rebuilds a descriptor from the bare pointer via
+  // MemRefDescriptor::fromStaticShape, which only gets the MemRefType, so
+  // dynamic sizes cannot be recovered. NOTE(review): assumed to also require
+  // the default (identity) layout. Reject anything else so the conversion
+  // fails cleanly instead of asserting or building a wrong descriptor.
+  if (!type.hasStaticShape() || !type.getAffineMaps().empty())
+    return {};
+
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  return elementType.getPointerTo(type.getMemorySpace());
+}
+
 LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                                LLVMTypeConverter &lowering_,
                                PatternBenefit benefit)
@@ -548,7 +608,84 @@
       for (unsigned idx : promotedArgIndices) {
         BlockArgument arg = firstBlock->getArgument(idx);
         Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
+        rewriter.replaceUsesOfWith(arg, loaded);
+      }
+    }
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
+// FuncOp conversion that lowers MemRef arguments to bare pointers to the
+// MemRef element type and rebuilds MemRef descriptors in the entry block.
+struct BarePtrFuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
+  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+    FunctionType type = funcOp.getType();
+    auto funcLoc = funcOp.getLoc();
+
+    // Store the positions of memref-typed arguments so that we can promote them
+    // to MemRef descriptor structs at the beginning of the function.
+    SmallVector<std::pair<unsigned, Type>, 4> promotedArgIndices;
+    promotedArgIndices.reserve(type.getNumInputs());
+    for (auto en : llvm::enumerate(type.getInputs()))
+      if (en.value().isa<MemRefType>())
+        promotedArgIndices.push_back({en.index(), en.value()});
+
+    // Convert the original function arguments. MemRef types are lowered to bare
+    // pointers to the MemRef element type; bail out if that is not possible.
+    auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
+    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
+    auto llvmType = lowering.convertFunctionSignature(
+        type, varargsAttr && varargsAttr.getValue(), result);
+    if (!llvmType)
+      return matchFailure();
+
+    // Only retain those attributes that are not constructed by build.
+    SmallVector<NamedAttribute, 4> attributes;
+    for (const auto &attr : funcOp.getAttrs()) {
+      if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+          attr.first.is(impl::getTypeAttrName()) ||
+          attr.first.is("std.varargs"))
+        continue;
+      attributes.push_back(attr);
+    }
+
+    // Create an LLVM function, use external linkage by default until MLIR
+    // functions have linkage.
+    auto newFuncOp =
+        rewriter.create<LLVM::LLVMFuncOp>(funcLoc, funcOp.getName(), llvmType,
+                                          LLVM::Linkage::External, attributes);
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    // Tell the rewriter to convert the region signature.
+    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+
+    // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
+    // at the beginning of the function so that all the MemRefs in the function
+    // have a uniform representation.
+    if (!newFuncOp.getBody().empty()) {
+      Block *firstBlock = &newFuncOp.getBody().front();
+      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
+      for (const auto &argIdxTypePair : promotedArgIndices) {
+        // Route every use of the bare pointer through an undef placeholder,
+        // build the descriptor from the bare pointer, then replace the
+        // placeholder through the rewriter (ops it created must not be erased
+        // directly); only the descriptor keeps using the raw pointer.
+        BlockArgument arg = firstBlock->getArgument(argIdxTypePair.first);
+        Value placeHolder =
+            rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
+        rewriter.replaceUsesOfWith(arg, placeHolder);
+        Value desc = MemRefDescriptor::fromStaticShape(
+            rewriter, funcLoc, lowering,
+            argIdxTypePair.second.cast<MemRefType>(), arg);
+        rewriter.replaceOp(placeHolder.getDefiningOp(), {desc});
       }
     }
 
@@ -2126,7 +2263,6 @@
   // clang-format off
   patterns.insert<
       DimOpLowering,
-      FuncOpConversion,
       LoadOpLowering,
       MemRefCastOpLowering,
       StoreOpLowering,
@@ -2139,8 +2275,26 @@
   // clang-format on
 }
 
+void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
+}
+
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
+  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+}
+
+void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
+}
+
+void mlir::populateStdToLLVMBarePtrConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
   populateStdToLLVMMemoryConversionPatters(converter, patterns);
 }
@@ -2210,6 +2364,12 @@
   return std::make_unique<LLVMTypeConverter>(context);
 }
 
+/// Create an instance of BarePtrTypeConverter in the given context.
+static std::unique_ptr<LLVMTypeConverter>
+makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) {
+  return std::make_unique<BarePtrTypeConverter>(context);
+}
+
 namespace {
 /// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass { @@ -2274,6 +2434,9 @@ "Standard to the LLVM dialect", [] { return std::make_unique( - clUseAlloca.getValue(), populateStdToLLVMConversionPatterns, - makeStandardToLLVMTypeConverter); + clUseAlloca.getValue(), + clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns + : populateStdToLLVMConversionPatterns, + clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter + : makeStandardToLLVMTypeConverter); }); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -861,8 +861,7 @@ return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { +void ConversionPatternRewriter::replaceUsesOfWith(Value from, Value to) { for (auto &u : from.getUses()) { if (u.getOwner() == to.getDefiningOp()) continue; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR + +// BAREPTR-LABEL: func @check_noalias +// BAREPTR-SAME: [[ARG:%.*]]: !llvm<"float*"> {llvm.noalias = true} +func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) { + return +} + +// WIP: Move tests with static shapes from convert-memref-ops.mlir here.