diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md
--- a/mlir/docs/ConversionToLLVMDialect.md
+++ b/mlir/docs/ConversionToLLVMDialect.md
@@ -248,58 +248,123 @@
 ### Calling Convention for `memref`
 
-For function _arguments_ of `memref` type, ranked or unranked, the type of the
-argument is a _pointer_ to the memref descriptor type defined above. The caller
-of such function is required to store the descriptor in memory and guarantee
-that the storage remains live until the callee returns. The caller can than pass
-the pointer to that memory as function argument. The callee loads from the
-pointers it was passed as arguments in the entry block of the function, making
-the descriptor passed in as argument available for use similarly to
-ocally-defined descriptors.
+Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a
+list of arguments of non-aggregate types that the memref descriptor defined
+above comprises. That is, the outer struct type and the inner array types are
+replaced with individual arguments.
 
 This convention is implemented in the conversion of `std.func` and `std.call` to
-the LLVM dialect. Conversions from other dialects should take it into account.
-The motivation for this convention is to simplify the ABI for interfacing with
-other LLVM modules, in particular those generated from C sources, while avoiding
-platform-specific aspects until MLIR has a proper ABI modeling.
+the LLVM dialect, with the former unpacking the descriptor into a set of
+individual values and the latter packing those values back into a descriptor so
+as to make it transparently usable by other operations. Conversions from other
+dialects should take this convention into account.
 
-Example:
+This specific convention is motivated by the necessity to specify alignment and
+aliasing attributes on the raw pointers underpinning the memref.
+
+Examples:
 
 ```mlir
+func @foo(%arg0: memref<?xf32>) -> () {
+  "use"(%arg0) : (memref<?xf32>) -> ()
+  return
+}
 
-func @foo(memref<?xf32>) -> () {
-  %c0 = constant 0 : index
-  load %arg0[%c0] : memref<?xf32>
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm<"float*">,   // Allocated pointer.
+               %arg1: !llvm<"float*">,   // Aligned pointer.
+               %arg2: !llvm.i64,         // Offset.
+               %arg3: !llvm.i64,         // Size in dim 0.
+               %arg4: !llvm.i64) {       // Stride in dim 0.
+  // Populate memref descriptor structure.
+  %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Descriptor is now usable as a single value.
+  "use"(%5) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">) -> ()
+  llvm.return
+}
+```
+
+```mlir
+func @bar() {
+  %0 = "get"() : () -> (memref<?xf32>)
+  call @foo(%0) : (memref<?xf32>) -> ()
   return
 }
 
-func @bar(%arg0: index) {
-  %0 = alloc(%arg0) : memref<?xf32>
-  call @foo(%0) : (memref<?xf32>)-> ()
+// Gets converted to the following.
+
+llvm.func @bar() {
+  %0 = "get"() : () -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Unpack the memref descriptor.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %5 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Pass individual values to the callee.
+  llvm.call @foo(%1, %2, %3, %4, %5) : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
+  llvm.return
+}
+```
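+
+For illustration, the expanded signature above corresponds to a C-level
+prototype along the following lines. This is a sketch: the exact ABI mapping
+depends on the target, and the parameter names are purely illustrative.
+
+```cpp
+#include <cstdint>
+
+// One pointer pair, the offset, then one size and one stride per dimension
+// of the ranked memref.
+void foo(float *allocated, float *aligned, int64_t offset, int64_t size0,
+         int64_t stride0);
+```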
+
+For **unranked** memrefs, the list of function arguments always contains two
+elements, same as the unranked memref descriptor: an integer rank, and a
+type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that
+while the _calling convention_ does not require stack allocation, _casting_ to
+unranked memref does since one cannot take an address of an SSA value containing
+the ranked memref. The caller is in charge of ensuring thread safety and of
+eventually removing unnecessary stack allocations in cast operations.
+
+Example:
+
+```mlir
+func @foo(%arg0: memref<*xf32>) -> () {
+  "use"(%arg0) : (memref<*xf32>) -> ()
   return
 }
-// Gets converted to the following IR.
-// Accepts a pointer to the memref descriptor.
-llvm.func @foo(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) {
-  // Loads the descriptor so that it can be used similarly to locally
-  // created descriptors.
-  %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm.i64,       // Rank.
+               %arg1: !llvm<"i8*">) {  // Type-erased pointer to descriptor.
+  // Pack the unranked memref descriptor.
+  %0 = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i8* }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i8* }">
+
+  "use"(%2) : (!llvm<"{ i64, i8* }">) -> ()
+  llvm.return
+}
+```
+
+```mlir
+func @bar() {
+  %0 = "get"() : () -> (memref<*xf32>)
+  call @foo(%0) : (memref<*xf32>) -> ()
+  return
 }
-llvm.func @bar(%arg0: !llvm.i64) {
-  // ... Allocation ...
-  // Definition of the descriptor.
-  %7 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-  // ... Filling in the descriptor ...
-  %14 = // The final value of the allocated descriptor.
-  // Allocate the memory for the descriptor and store it.
-  %15 = llvm.mlir.constant(1 : index) : !llvm.i64
-  %16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-    : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
-  llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
-  // Pass the pointer to the function.
-  llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
+// Gets converted to the following.
+
+llvm.func @bar() {
+  %0 = "get"() : () -> (!llvm<"{ i64, i8* }">)
+
+  // Unpack the memref descriptor.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ i64, i8* }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ i64, i8* }">
+
+  // Pass individual values to the callee.
+  llvm.call @foo(%1, %2) : (!llvm.i64, !llvm<"i8*">) -> ()
   llvm.return
 }
 ```
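+
+At the C level, the unranked form would correspond to a prototype along these
+lines; again a sketch with illustrative names.
+
+```cpp
+#include <cstdint>
+
+// The dynamic rank followed by a type-erased pointer to a stack-allocated
+// ranked descriptor.
+void foo(int64_t rank, void *rankedDescriptor);
+```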
@@ -307,6 +372,141 @@
 
 *This convention may or may not apply if the conversion of MemRef types is
 overridden by the user.*
 
+### C-compatible wrapper emission
+
+In practical cases, it may be desirable to have externally-facing functions
+accept a single argument for each memref argument. When interfacing with LLVM
+IR produced from C, the code needs to respect the corresponding calling
+convention. The conversion to the LLVM dialect provides an option to generate
+wrapper functions that take memref descriptors as pointers-to-struct,
+compatible with the data types produced by Clang when compiling C sources.
+
+More specifically, a memref argument is converted into a pointer-to-struct
+argument of type `{T*, T*, i64, i64[N], i64[N]}*` in the wrapper function, where
+`T` is the converted element type and `N` is the memref rank. This type is
+compatible with that produced by Clang for the following C++ structure template
+instantiations or their equivalents in C.
+
+```cpp
+template <typename T, size_t N>
+struct MemRefDescriptor {
+  T *allocated;
+  T *aligned;
+  intptr_t offset;
+  intptr_t sizes[N];
+  intptr_t strides[N];
+};
+```
+
+If enabled, the option will do the following. For _external_ functions declared
+in the MLIR module:
+
+1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
+   are converted to pointer-to-struct and the remaining arguments are converted
+   as usual.
+1. Add a body to the original function (making it non-external) that
+   1. allocates a memref descriptor,
+   1. populates it,
+   1. passes the pointer to it into the newly declared interface function, and
+   1. collects the result of the call and returns it to the caller.
+
+For (non-external) functions defined in the MLIR module:
+
+1. Define a new function `_mlir_ciface_<original name>` where memref arguments
+   are converted to pointer-to-struct and the remaining arguments are converted
+   as usual.
+1. Populate the body of the newly defined function with IR that
+   1. loads descriptors from pointers;
+   1. unpacks the descriptor into individual non-aggregate values;
+   1. passes these values into the original function;
+   1. collects the result of the call and returns it to the caller.
+
+Examples:
+
+```mlir
+func @qux(%arg0: memref<?x?xf32>)
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @qux(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+               %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+               %arg6: !llvm.i64) {
+  // Populate memref descriptor (as per calling convention).
+  %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+  // Store the descriptor in a stack-allocated space.
+  %8 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %9 = llvm.alloca %8 x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+    : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  llvm.store %7, %9 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+  // Call the interface function.
+  llvm.call @_mlir_ciface_qux(%9) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+  // The stored descriptor will be freed on return.
+  llvm.return
+}
+
+// Interface function.
+llvm.func @_mlir_ciface_qux(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+```
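+
+Since `@qux` is external, the interface function is expected to be provided
+elsewhere, for example in C++. The following is a hypothetical sketch of such
+an implementation; it reuses the `MemRefDescriptor` template shown above, and
+the loop body is purely illustrative.
+
+```cpp
+#include <cstdint>
+
+// Hypothetical C++ implementation of the interface function declared above.
+// The MLIR-generated body of @qux packs its unpacked memref arguments into a
+// descriptor and calls this function with a pointer to it.
+extern "C" void _mlir_ciface_qux(MemRefDescriptor<float, 2> *desc) {
+  // Address elements through the aligned pointer using the descriptor's
+  // offset, sizes and strides.
+  for (intptr_t i = 0; i < desc->sizes[0]; ++i)
+    for (intptr_t j = 0; j < desc->sizes[1]; ++j)
+      desc->aligned[desc->offset + i * desc->strides[0] +
+                    j * desc->strides[1]] += 1.0f;
+}
+```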
+
+```mlir
+func @foo(%arg0: memref<?x?xf32>) {
+  return
+}
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @foo(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+               %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+               %arg6: !llvm.i64) {
+  llvm.return
+}
+
+// Interface function callable from C.
+llvm.func @_mlir_ciface_foo(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+  // Load the descriptor.
+  %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+  // Unpack the descriptor as per calling convention.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %5 = llvm.extractvalue %0[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %6 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %7 = llvm.extractvalue %0[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
+    : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64,
+       !llvm.i64, !llvm.i64) -> ()
+  llvm.return
+}
+```
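+
+Calling the emitted wrapper from C++ could then look as follows; a sketch
+that again reuses the `MemRefDescriptor` template, with illustrative names,
+sizes and strides.
+
+```cpp
+#include <cstdint>
+
+// Hypothetical C++ caller of the emitted interface function for @foo.
+extern "C" void _mlir_ciface_foo(MemRefDescriptor<float, 2> *);
+
+void callFoo(float *buffer) {
+  // Describe a contiguous row-major 4x8 buffer: offset 0, sizes {4, 8},
+  // strides {8, 1}.
+  MemRefDescriptor<float, 2> desc{buffer, buffer, 0, {4, 8}, {8, 1}};
+  _mlir_ciface_foo(&desc);
+}
+```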
+
+Rationale: Introducing auxiliary functions for C-compatible interfaces is
+preferred to modifying the calling convention since it will minimize the effect
+of C compatibility on intra-module calls or calls between MLIR-generated
+functions. In particular, when calling external functions from an MLIR module
+in a (parallel) loop, storing a memref descriptor on the stack can lead to
+stack exhaustion and/or to concurrent access to the same address. The auxiliary
+interface function serves as an allocation scope in this case. Furthermore,
+when targeting accelerators with separate memory spaces such as GPUs,
+stack-allocated descriptors passed by pointer would have to be transferred to
+the device memory, which introduces significant overhead. In such situations,
+auxiliary interface functions are executed on the host and only pass the values
+through the device function invocation mechanism.
+
 ## Repeated Successor Removal
 
 Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the
 dialect
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -36,8 +36,8 @@
 
 /// Set of callbacks that allows the customization of LLVMTypeConverter.
 struct LLVMTypeConverterCustomization {
-  using CustomCallback =
-      std::function<LLVM::LLVMType(LLVMTypeConverter &, Type)>;
+  using CustomCallback = std::function<LogicalResult(
+      LLVMTypeConverter &, Type, SmallVectorImpl<Type> &)>;
 
   /// Customize the type conversion of function arguments.
   CustomCallback funcArgConverter;
@@ -47,19 +47,26 @@
 };
 
 /// Callback to convert function argument types. It converts a MemRef function
-/// argument to a struct that contains the descriptor information. Converted
-/// types are promoted to a pointer to the converted type.
-LLVM::LLVMType structFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                          Type type);
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                         Type type,
+                                         SmallVectorImpl<Type> &result);
 
 /// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                           Type type);
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                          Type type,
+                                          SmallVectorImpl<Type> &result);
 
 /// Conversion from types in the Standard dialect to the LLVM IR dialect.
 class LLVMTypeConverter : public TypeConverter {
+  /// Give structFuncArgTypeConverter access to memref-specific functions.
+  friend LogicalResult
+  structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
+                             SmallVectorImpl<Type> &result);
+
 public:
   using TypeConverter::convertType;
@@ -107,6 +114,15 @@
   Value promoteOneMemRefDescriptor(Location loc, Value operand,
                                    OpBuilder &builder);
 
+  /// Converts the function type to a C-compatible format, in particular using
+  /// pointers to memref descriptors for arguments.
+  LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
+
+  /// Creates descriptor structs from individual values constituting them.
+  Operation *materializeConversion(PatternRewriter &rewriter, Type type,
+                                   ArrayRef<Value> values,
+                                   Location loc) override;
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
@@ -133,14 +149,34 @@
   // by LLVM.
   Type convertFloatType(FloatType type);
 
-  // Convert a memref type into an LLVM type that captures the relevant data.
-  // For statically-shaped memrefs, the resulting type is a pointer to the
-  // (converted) memref element type. For dynamically-shaped memrefs, the
-  // resulting type is an LLVM structure type that contains:
-  //   1. a pointer to the (converted) memref element type
-  //   2. as many index types as memref has dynamic dimensions.
+  /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type);
 
+  /// Convert a memref type into a list of non-aggregate LLVM IR types that
+  /// contain all the relevant data. In particular, the list will contain:
+  /// - two pointers to the memref element type, followed by
+  /// - an integer offset, followed by
+  /// - one integer size per dimension of the memref, followed by
+  /// - one integer stride per dimension of the memref.
+  /// For example, memref<?x?xf32> is converted to the following list:
+  /// - `!llvm<"float*">` (allocated pointer),
+  /// - `!llvm<"float*">` (aligned pointer),
+  /// - `!llvm.i64` (offset),
+  /// - `!llvm.i64`, `!llvm.i64` (sizes),
+  /// - `!llvm.i64`, `!llvm.i64` (strides).
+  /// These types can be recomposed to a memref descriptor struct.
+  SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
+
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+  /// that contain all the relevant data. In particular, this list contains:
+  /// - an integer rank, followed by
+  /// - a pointer to the memref descriptor struct.
+  /// For example, memref<*xf32> is converted to the following list:
+  /// !llvm.i64 (rank)
+  /// !llvm<"i8*"> (type-erased pointer).
+  /// These types can be recomposed to an unranked memref descriptor struct.
+  SmallVector<Type, 2> convertUnrankedMemRefSignature();
+
   // Convert an unranked memref type to an LLVM type that captures the
   // runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type);
@@ -180,6 +216,7 @@
   /// Builds IR to set a value in the struct at position pos
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };
+
 /// Helper class to produce LLVM dialect operations extracting or inserting
 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
 /// The Value may be null, in which case none of the operations are valid.
@@ -234,11 +271,63 @@
   /// Returns the (LLVM) type this descriptor points to.
   LLVM::LLVMType getElementType();
 
+  /// Builds IR populating a MemRef descriptor structure from a list of
+  /// individual values composing that descriptor, in the following order:
+  /// - allocated pointer;
+  /// - aligned pointer;
+  /// - offset;
+  /// - <rank> sizes;
+  /// - <rank> strides;
+  /// where <rank> is the MemRef rank as provided in `type`.
+  static Value pack(OpBuilder &builder, Location loc,
+                    LLVMTypeConverter &converter, MemRefType type,
+                    ValueRange values);
+
+  /// Builds IR extracting individual elements of a MemRef descriptor structure
+  /// and returning them as `results` list.
+  static void unpack(OpBuilder &builder, Location loc, Value packed,
+                     MemRefType type, SmallVectorImpl<Value> &results);
+
+  /// Returns the number of non-aggregate values that would be produced by
+  /// `unpack`.
+  static unsigned getNumUnpackedValues(MemRefType type);
+
 private:
   // Cached index type.
   Type indexType;
 };
 
+/// Helper class allowing the user to access a range of Values that correspond
+/// to an unpacked memref descriptor using named accessors. This does not own
+/// the values.
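+/// For example, given a ValueRange `operands` holding the unpacked values of
+/// a rank-1 descriptor (a usage sketch with illustrative names):
+///
+///   MemRefDescriptorView view(operands);
+///   Value aligned = view.alignedPtr(); // operands[1]
+///   Value size0 = view.size(0);        // operands[3]
+///   Value stride0 = view.stride(0);    // operands[4]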
+class MemRefDescriptorView {
+public:
+  /// Constructs the view from a range of values. Infers the rank from the size
+  /// of the range.
+  explicit MemRefDescriptorView(ValueRange range);
+
+  /// Returns the allocated pointer Value.
+  Value allocatedPtr();
+
+  /// Returns the aligned pointer Value.
+  Value alignedPtr();
+
+  /// Returns the offset Value.
+  Value offset();
+
+  /// Returns the pos-th size Value.
+  Value size(unsigned pos);
+
+  /// Returns the pos-th stride Value.
+  Value stride(unsigned pos);
+
+private:
+  /// Rank of the memref the descriptor is pointing to.
+  int rank;
+  /// Underlying range of Values.
+  ValueRange elements;
+};
+
 class UnrankedMemRefDescriptor : public StructBuilder {
 public:
   /// Construct a helper for the given descriptor value.
@@ -255,6 +344,23 @@
   Value memRefDescPtr(OpBuilder &builder, Location loc);
   /// Builds IR setting ranked memref descriptor ptr
   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
+
+  /// Builds IR populating an unranked MemRef descriptor structure from a list
+  /// of individual constituent values in the following order:
+  /// - rank of the memref;
+  /// - pointer to the memref descriptor.
+  static Value pack(OpBuilder &builder, Location loc,
+                    LLVMTypeConverter &converter, UnrankedMemRefType type,
+                    ValueRange values);
+
+  /// Builds IR extracting individual elements that compose an unranked memref
+  /// descriptor and returns them as `results` list.
+  static void unpack(OpBuilder &builder, Location loc, Value packed,
+                     SmallVectorImpl<Value> &results);
+
+  /// Returns the number of non-aggregate values that would be produced by
+  /// `unpack`.
+  static unsigned getNumUnpackedValues() { return 2; }
 };
 
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
 /// conversion patterns with an access to the containing LLVMLowering for the
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -29,16 +29,21 @@
 void populateStdToLLVMNonMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
-/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
+/// `emitCWrappers` is set, the pattern will also produce functions
+/// that pass memref descriptors by pointer-to-structure in addition to the
+/// default unpacked form.
 void populateStdToLLVMDefaultFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool emitCWrappers = false);
 
 /// Collect a set of default patterns to convert from the Standard dialect to
 /// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will
 /// generate `llvm.alloca` instead of calls to "malloc".
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool useAlloca = false);
+                                         bool useAlloca = false,
+                                         bool emitCWrappers = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
@@ -53,7 +58,7 @@
 /// Specifying `useAlloca-true` emits stack allocations instead. In the future
 /// this may become an enum when we have concrete uses for other options.
 std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(bool useAlloca = false);
+createLowerToLLVMPass(bool useAlloca = false, bool emitCWrappers = false);
 
 namespace LLVM {
 /// Make argument-taking successors of each block distinct. PHI nodes in LLVM
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -30,6 +30,13 @@
   return ("arg" + Twine(arg)).toStringRef(out);
 }
 
+/// Returns true if the given name is a valid argument attribute name.
+inline bool isArgAttrName(StringRef name) {
+  APInt unused;
+  return name.startswith("arg") &&
+         !name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
+}
+
 /// Return the name of the attribute used for function results.
 inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
   out.clear();
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -113,6 +113,8 @@
   }
 
   void declareCudaFunctions(Location loc);
+  void addParamToList(OpBuilder &builder, Location loc, Value param, Value list,
+                      unsigned pos, Value one);
   Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
   Value generateKernelNameConstant(StringRef name, Location loc,
                                    OpBuilder &builder);
@@ -231,6 +233,35 @@
   }
 }
 
+/// Emits the IR with the following structure:
+///
+///   %data = llvm.alloca 1 x type-of(<param>)
+///   llvm.store <param>, %data
+///   %typeErased = llvm.bitcast %data to !llvm<"i8*">
+///   %addr = llvm.getelementptr <list>[<pos>]
+///   llvm.store %typeErased, %addr
+///
+/// This is necessary to construct the list of arguments passed to the kernel
+/// function as accepted by cuLaunchKernel, i.e. as a void** that points to a
+/// list of stack-allocated type-erased pointers to the actual arguments.
+void GpuLaunchFuncToCudaCallsPass::addParamToList(OpBuilder &builder,
+                                                  Location loc, Value param,
+                                                  Value list, unsigned pos,
+                                                  Value one) {
+  auto memLocation = builder.create<LLVM::AllocaOp>(
+      loc, param.getType().cast<LLVM::LLVMType>().getPointerTo(), one,
+      /*alignment=*/1);
+  builder.create<LLVM::StoreOp>(loc, param, memLocation);
+  auto casted =
+      builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+
+  auto index = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+                                                builder.getI32IntegerAttr(pos));
+  auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), list,
+                                         ArrayRef<Value>{index});
+  builder.create<LLVM::StoreOp>(loc, casted, gep);
+}
+
 // Generates a parameters array to be used with a CUDA kernel launch call. The
 // arguments are extracted from the launchOp.
 // The generated code is essentially as follows:
 //
@@ -241,53 +272,66 @@
 // return %array
 Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
                                                      OpBuilder &builder) {
+
+  // Get the launch target.
+  auto containingModule = launchOp.getParentOfType<ModuleOp>();
+  if (!containingModule)
+    return {};
+  auto gpuModule = containingModule.lookupSymbol<gpu::GPUModuleOp>(
+      launchOp.getKernelModuleName());
+  if (!gpuModule)
+    return {};
+  auto gpuFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>(launchOp.kernel());
+  if (!gpuFunc)
+    return {};
+
+  unsigned numArgs = gpuFunc.getNumArguments();
+
   auto numKernelOperands = launchOp.getNumKernelOperands();
   Location loc = launchOp.getLoc();
   auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
                                               builder.getI32IntegerAttr(1));
-  // Provision twice as much for the `array` to allow up to one level of
-  // indirection for each argument.
   auto arraySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
+      loc, getInt32Type(), builder.getI32IntegerAttr(numArgs));
   auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
                                               arraySize, /*alignment=*/0);
+
+  unsigned pos = 0;
   for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
     auto operand = launchOp.getKernelOperand(idx);
     auto llvmType = operand.getType().cast<LLVM::LLVMType>();
-    Value memLocation = builder.create<LLVM::AllocaOp>(
-        loc, llvmType.getPointerTo(), one, /*alignment=*/1);
-    builder.create<LLVM::StoreOp>(loc, operand, memLocation);
-    auto casted =
-        builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
 
     // Assume all struct arguments come from MemRef. If this assumption does not
    // hold anymore then we `launchOp` to lower from MemRefType and not after
    // LLVMConversion has taken place and the MemRef information is lost.
-    // Extra level of indirection in the `array`:
-    //   the descriptor pointer is registered via @mcuMemHostRegisterPtr
-    if (llvmType.isStructTy()) {
-      auto registerFunc =
-          getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
-      auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
-      auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
-                                             ArrayRef<Value>{nullPtr, one});
-      auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
-      builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
-                                   builder.getSymbolRefAttr(registerFunc),
-                                   ArrayRef<Value>{casted, size});
-      Value memLocation = builder.create<LLVM::AllocaOp>(
-          loc, getPointerPointerType(), one, /*alignment=*/1);
-      builder.create<LLVM::StoreOp>(loc, casted, memLocation);
-      casted =
-          builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+    if (!llvmType.isStructTy()) {
+      addParamToList(builder, loc, operand, array, pos++, one);
+      continue;
     }
 
-    auto index = builder.create<LLVM::ConstantOp>(
-        loc, getInt32Type(), builder.getI32IntegerAttr(idx));
-    auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
-                                           ArrayRef<Value>{index});
-    builder.create<LLVM::StoreOp>(loc, casted, gep);
+    // Put individual components of a memref descriptor into the flat argument
+    // list. We cannot use unpackMemref from LLVM lowering here because we have
+    // no access to MemRefType that had been lowered away.
+    for (int32_t j = 0, ej = llvmType.getStructNumElements(); j < ej; ++j) {
+      auto elemType = llvmType.getStructElementType(j);
+      if (elemType.isArrayTy()) {
+        for (int32_t k = 0, ek = elemType.getArrayNumElements(); k < ek; ++k) {
+          Value elem = builder.create<LLVM::ExtractValueOp>(
+              loc, elemType.getArrayElementType(), operand,
+              builder.getI32ArrayAttr({j, k}));
+          addParamToList(builder, loc, elem, array, pos++, one);
+        }
+      } else {
+        assert((elemType.isIntegerTy() || elemType.isFloatTy() ||
                elemType.isDoubleTy() || elemType.isPointerTy()) &&
+               "expected scalar type");
+        Value strct = builder.create<LLVM::ExtractValueOp>(
+            loc, elemType, operand, builder.getI32ArrayAttr(j));
+        addParamToList(builder, loc, strct, array, pos++, one);
+      }
+    }
   }
+
   return array;
 }
 
@@ -392,6 +436,10 @@
   auto cuFunctionRef =
      builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
+  if (!paramsArray) {
+    launchOp.emitOpError() << "cannot pass given parameters to the kernel";
+    return signalPassFailure();
+  }
   auto nullpointer =
      builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
   builder.create<LLVM::CallOp>(
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -564,8 +564,8 @@
     // Remap proper input types.
     TypeConverter::SignatureConversion signatureConversion(
         gpuFuncOp.front().getNumArguments());
-    for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
-      signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
+    lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false,
+                                      signatureConversion);
 
     // Create the new function operation. Only copy those attributes that are
     // not specific to function modeling.
@@ -651,25 +651,6 @@
     rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
                                       signatureConversion);
 
-    {
-      // For memref-typed arguments, insert the relevant loads in the beginning
-      // of the block to comply with the LLVM dialect calling convention. This
-      // needs to be done after signature conversion to get the right types.
-      OpBuilder::InsertionGuard guard(rewriter);
-      Block &block = llvmFuncOp.front();
-      rewriter.setInsertionPointToStart(&block);
-
-      for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
-        if (!en.value().isa<MemRefType>() &&
-            !en.value().isa<UnrankedMemRefType>())
-          continue;
-
-        BlockArgument arg = block.getArgument(en.index());
-        Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
-      }
-    }
-
     rewriter.eraseOp(gpuFuncOp);
     return matchSuccess();
   }
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -577,7 +577,8 @@
   LinalgTypeConverter converter(&getContext());
   populateAffineToStdConversionPatterns(patterns, &getContext());
   populateLoopToStdConversionPatterns(patterns, &getContext());
-  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
+                                      /*emitCWrappers=*/true);
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateLinalgToStandardConversionPatterns(patterns, &getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -43,6 +44,11 @@
     llvm::cl::desc("Replace emission of malloc/free by alloca"),
     llvm::cl::init(false));
 
+static llvm::cl::opt<bool>
+    clEmitCWrappers(PASS_NAME "-emit-c-wrappers",
+                    llvm::cl::desc("Emit C-compatible wrapper functions"),
+                    llvm::cl::init(false));
+
 static llvm::cl::opt<bool> clUseBarePtrCallConv(
     PASS_NAME "-use-bare-ptr-memref-call-conv",
    llvm::cl::desc("Replace FuncOp's MemRef arguments with "
@@ -66,18 +72,32 @@
   funcArgConverter = structFuncArgTypeConverter;
 }
 
-// Callback to convert function argument types. It converts a MemRef function
-// arguments to a struct that contains the descriptor information. Converted
-// types are promoted to a pointer to the converted type.
-LLVM::LLVMType mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                                Type type) {
-  auto converted =
-      converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+/// Callback to convert function argument types. It converts a MemRef function
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                               Type type,
+                                               SmallVectorImpl<Type> &result) {
+  if (auto memref = type.dyn_cast<MemRefType>()) {
+    auto converted = converter.convertMemRefSignature(memref);
+    if (converted.empty())
+      return failure();
+    result.append(converted.begin(), converted.end());
+    return success();
+  }
+  if (type.isa<UnrankedMemRefType>()) {
+    auto converted = converter.convertUnrankedMemRefSignature();
+    if (converted.empty())
+      return failure();
+    result.append(converted.begin(), converted.end());
+    return success();
+  }
+  auto converted = converter.convertType(type);
   if (!converted)
-    return {};
-  if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
-    converted = converted.getPointerTo();
-  return converted;
+    return failure();
+  result.push_back(converted);
+  return success();
 }
 
 /// Convert a MemRef type to a bare pointer to the MemRef element type.
@@ -96,15 +116,26 @@
 }
 
 /// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                                 Type type) {
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                                Type type,
+                                                SmallVectorImpl<Type> &result) {
   // TODO: Add support for unranked memref.
-  if (auto memrefTy = type.dyn_cast<MemRefType>())
-    return convertMemRefTypeToBarePtr(converter, memrefTy)
-        .dyn_cast_or_null<LLVM::LLVMType>();
-  return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+  if (auto memrefTy = type.dyn_cast<MemRefType>()) {
+    auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
+    if (!llvmTy)
+      return failure();
+
+    result.push_back(llvmTy);
+    return success();
+  }
+
+  auto llvmTy = converter.convertType(type);
+  if (!llvmTy)
+    return failure();
+
+  result.push_back(llvmTy);
+  return success();
 }
 
 /// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
@@ -165,6 +196,33 @@
   return converted.getPointerTo();
 }
 
+/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
+/// values.
+SmallVector<Type, 5>
+LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
+  SmallVector<Type, 5> results;
+  assert(isStrided(type) &&
         "Non-strided layout maps must have been normalized away");
+
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  auto indexTy = getIndexType();
+
+  results.insert(results.begin(), 2,
+                 elementType.getPointerTo(type.getMemorySpace()));
+  results.push_back(indexTy);
+  auto rank = type.getRank();
+  results.insert(results.end(), 2 * rank, indexTy);
+  return results;
+}
+
+/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
+/// pointer to descriptor".
+SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
+  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+}
+
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types. If MLIR Function has zero results, the LLVM
 // Function has one VoidType result. If MLIR Function has more than one result,
@@ -175,9 +233,8 @@
 
   // Convert argument types one by one and check for errors.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Type type = en.value();
-    auto converted = customizations.funcArgConverter(*this, type)
-                         .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!converted)
+    SmallVector<Type, 8> converted;
+    if (failed(customizations.funcArgConverter(*this, type, converted)))
      return {};
    result.addInputs(en.index(), converted);
  }
@@ -199,6 +256,47 @@
   return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
 }
 
+/// Converts the function type to a C-compatible format, in particular using
+/// pointers to memref descriptors for arguments.
+LLVM::LLVMType
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+  SmallVector<LLVM::LLVMType, 4> inputs;
+
+  for (Type t : type.getInputs()) {
+    auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
+    if (!converted)
+      return {};
+    if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+      converted = converted.getPointerTo();
+    inputs.push_back(converted);
+  }
+
+  LLVM::LLVMType resultType =
+      type.getNumResults() == 0
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+
+  return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
+}
+
+/// Creates descriptor structs from individual values constituting them.
+Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
+                                                    Type type,
+                                                    ArrayRef<Value> values,
+                                                    Location loc) {
+  if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
+                                          unrankedMemRefType, values)
+        .getDefiningOp();
+
+  auto memRefType = type.dyn_cast<MemRefType>();
+  assert(memRefType && "1->N conversion is only supported for memrefs");
+  return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
+      .getDefiningOp();
+}
+
 // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
 // contains:
 //   1. the pointer to the data buffer, followed by
@@ -473,6 +571,85 @@
                 kAlignedPtrPosInMemRefDescriptor);
 }
 
+/// Creates a MemRef descriptor structure from a list of individual values
+/// composing that descriptor, in the following order:
+/// - allocated pointer;
+/// - aligned pointer;
+/// - offset;
+/// - <rank> sizes;
+/// - <rank> strides;
+/// where <rank> is the MemRef rank as provided in `type`.
+Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &converter, MemRefType type,
+                             ValueRange values) {
+  Type llvmType = converter.convertType(type);
+  auto d = MemRefDescriptor::undef(builder, loc, llvmType);
+
+  d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
+  d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
+  d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
+
+  int64_t rank = type.getRank();
+  for (unsigned i = 0; i < rank; ++i) {
+    d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
+    d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
+  }
+
+  return d;
+}
+
+/// Builds IR extracting individual elements of a MemRef descriptor structure
+/// and returning them as `results` list.
+void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
+                              MemRefType type,
+                              SmallVectorImpl<Value> &results) {
+  int64_t rank = type.getRank();
+  results.reserve(results.size() + getNumUnpackedValues(type));
+
+  MemRefDescriptor d(packed);
+  results.push_back(d.allocatedPtr(builder, loc));
+  results.push_back(d.alignedPtr(builder, loc));
+  results.push_back(d.offset(builder, loc));
+  for (int64_t i = 0; i < rank; ++i)
+    results.push_back(d.size(builder, loc, i));
+  for (int64_t i = 0; i < rank; ++i)
+    results.push_back(d.stride(builder, loc, i));
+}
+
+/// Returns the number of non-aggregate values that would be produced by
+/// `unpack`.
+unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
+  // Two pointers, offset, <rank> sizes, <rank> strides.
+  return 3 + 2 * type.getRank();
+}
+
+/*============================================================================*/
+/* MemRefDescriptorView implementation.                                       */
+/*============================================================================*/
+
+MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
+    : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
+
+Value MemRefDescriptorView::allocatedPtr() {
+  return elements[kAllocatedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::alignedPtr() {
+  return elements[kAlignedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::offset() {
+  return elements[kOffsetPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::size(unsigned pos) {
+  return elements[kSizePosInMemRefDescriptor + pos];
+}
+
+Value MemRefDescriptorView::stride(unsigned pos) {
+  return elements[kSizePosInMemRefDescriptor + rank + pos];
+}
+
 /*============================================================================*/
 /* UnrankedMemRefDescriptor implementation                                    */
 /*============================================================================*/
@@ -504,6 +681,34 @@
                                                   Location loc, Value v) {
   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
 }
+
+/// Builds IR populating an unranked MemRef descriptor structure from a list
+/// of individual constituent values in the following order:
+/// - rank of the memref;
+/// - pointer to the memref descriptor.
+Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
+                                     LLVMTypeConverter &converter,
+                                     UnrankedMemRefType type,
+                                     ValueRange values) {
+  Type llvmType = converter.convertType(type);
+  auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
+
+  d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
+  d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
+  return d;
+}
+
+/// Builds IR extracting individual elements that compose an unranked memref
+/// descriptor and returns them as `results` list.
+void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
+                                      Value packed,
+                                      SmallVectorImpl<Value> &results) {
+  UnrankedMemRefDescriptor d(packed);
+  results.reserve(results.size() + 2);
+  results.push_back(d.rank(builder, loc));
+  results.push_back(d.memRefDescPtr(builder, loc));
+}
+
 namespace {
 // Base class for Standard to LLVM IR op conversions. Matches the Op type
 // provided as template argument. Carries a reference to the LLVM dialect in
@@ -551,9 +756,144 @@
   LLVM::LLVMDialect &dialect;
 };
 
+/// Only retain those attributes that are not constructed by
+/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
+/// attributes.
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+                                 bool filterArgAttrs,
+                                 SmallVectorImpl<NamedAttribute> &result) {
+  for (const auto &attr : attrs) {
+    if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+        attr.first.is(impl::getTypeAttrName()) ||
+        attr.first.is("std.varargs") ||
+        (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
+      continue;
+    result.push_back(attr);
+  }
+}
+
+/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
+/// arguments instead of unpacked arguments. This function can be called from C
+/// by passing a pointer to a C struct corresponding to a memref descriptor.
+/// Internally, the auxiliary function unpacks the descriptor into individual
+/// components and forwards them to `newFuncOp`.
+static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
+                                   LLVMTypeConverter &typeConverter,
+                                   FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
+  auto type = funcOp.getType();
+  SmallVector<NamedAttribute, 4> attributes;
+  filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
+  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+      typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
+      attributes);
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
+
+  SmallVector<Value, 8> args;
+  for (auto &en : llvm::enumerate(type.getInputs())) {
+    Value arg = wrapperFuncOp.getArgument(en.index());
+    if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
+      Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+      MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
+      continue;
+    }
+    if (en.value().isa<UnrankedMemRefType>()) {
+      Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+      UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
+      continue;
+    }
+
+    args.push_back(wrapperFuncOp.getArgument(en.index()));
+  }
+  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
+  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+}
+
+/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
+/// arguments instead of unpacked arguments. Creates a body for the (external)
+/// `newFuncOp` that allocates a memref descriptor on stack, packs the
+/// individual arguments into this descriptor and passes a pointer to it into
+/// the auxiliary function. This auxiliary external function is now compatible
+/// with functions defined in C using pointers to C structs corresponding to a
+/// memref descriptor.
+static void wrapExternalFunction(OpBuilder &builder, Location loc,
+                                 LLVMTypeConverter &typeConverter,
+                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  LLVM::LLVMType wrapperType =
+      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
+  // This conversion can only fail if it could not convert one of the argument
+  // types. But since it has been applied to a non-wrapper function before, it
+  // should have failed earlier and not reach this point at all.
+  assert(wrapperType && "unexpected type conversion failure");
+
+  SmallVector<NamedAttribute, 4> attributes;
+  filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
+
+  // Create the auxiliary function.
+  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
+      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
+      wrapperType, LLVM::Linkage::External, attributes);
+
+  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
+
+  // Get a ValueRange containing arguments. Note that ValueRange is
+  // currently not constructible from a pair of iterators pointing to
+  // BlockArgument.
+  FunctionType type = funcOp.getType();
+  SmallVector<Value, 8> args;
+  args.reserve(type.getNumInputs());
+  auto wrapperArgIters = newFuncOp.getArguments();
+  SmallVector<Value, 8> wrapperArgs(wrapperArgIters.begin(),
+                                    wrapperArgIters.end());
+  ValueRange wrapperArgsRange(wrapperArgs);
+
+  // Iterate over the inputs of the original function and pack values into
+  // memref descriptors if the original type is a memref.
+  for (auto &en : llvm::enumerate(type.getInputs())) {
+    Value arg;
+    int numToDrop = 1;
+    auto memRefType = en.value().dyn_cast<MemRefType>();
+    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
+    if (memRefType || unrankedMemRefType) {
+      numToDrop = memRefType
+                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
+                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
+      Value packed =
+          memRefType
+              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
+                                       wrapperArgsRange.take_front(numToDrop))
+              : UnrankedMemRefDescriptor::pack(
+                    builder, loc, typeConverter, unrankedMemRefType,
+                    wrapperArgsRange.take_front(numToDrop));
+
+      auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
+      Value one = builder.create<LLVM::ConstantOp>(
+          loc, typeConverter.convertType(builder.getIndexType()),
+          builder.getIntegerAttr(builder.getIndexType(), 1));
+      Value allocated =
+          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
+      builder.create<LLVM::StoreOp>(loc, packed, allocated);
+      arg = allocated;
+    } else {
+      arg = wrapperArgsRange[0];
+    }
+
+    args.push_back(arg);
+    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
+  }
+  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
+
+  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
+  builder.create<LLVM::ReturnOp>(loc, call.getResults());
+}
+
 struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
 protected:
-  using LLVMLegalizationPattern::LLVMLegalizationPattern;
+  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
   using UnsignedTypePair = std::pair<unsigned, Type>;
 
   // Gather the positions and types of memref-typed arguments in a given
@@ -579,14 +919,24 @@
     auto llvmType = lowering.convertFunctionSignature(
         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
 
-    // Only retain those attributes that are not constructed by build.
+    // Propagate argument attributes to all converted arguments obtained after
+    // converting a given original argument.
     SmallVector<NamedAttribute, 4> attributes;
-    for (const auto &attr : funcOp.getAttrs()) {
-      if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
-          attr.first.is(impl::getTypeAttrName()) ||
-          attr.first.is("std.varargs"))
+    filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
+                         attributes);
+    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+      auto attr = impl::getArgAttrDict(funcOp, i);
+      if (!attr)
         continue;
-      attributes.push_back(attr);
+
+      auto mapping = result.getInputMapping(i);
+      assert(mapping.hasValue() && "unexpected deletion of function argument");
+
+      SmallString<8> name;
+      for (size_t j = mapping->inputNo, ej = mapping->inputNo + mapping->size;
+           j < ej; ++j) {
+        impl::getArgAttrName(j, name);
+        attributes.push_back(rewriter.getNamedAttr(name, attr));
+      }
     }
 
     // Create an LLVM function, use external linkage by default until MLIR
@@ -607,34 +957,33 @@
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
 struct FuncOpConversion : public FuncOpConversionBase {
-  using FuncOpConversionBase::FuncOpConversionBase;
+  FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
+                   bool emitCWrappers)
+      : FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
 
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto funcOp = cast<FuncOp>(op);
 
-    // Store the positions of memref-typed arguments so that we can emit loads
-    // from them to follow the calling convention.
-    SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
-    getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
-
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
-
-    // Insert loads from memref descriptor pointers in function bodies.
-    if (!newFuncOp.getBody().empty()) {
-      Block *firstBlock = &newFuncOp.getBody().front();
-      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
-      for (const auto &argInfo : promotedArgsInfo) {
-        BlockArgument arg = firstBlock->getArgument(argInfo.first);
-        Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
-      }
+    if (emitWrappers) {
+      if (newFuncOp.isExternal())
+        wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp,
+                             newFuncOp);
+      else
+        wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp,
+                               newFuncOp);
     }
 
     rewriter.eraseOp(op);
     return matchSuccess();
   }
+
+private:
+  /// If true, also create the adaptor functions having signatures compatible
+  /// with those produced by clang.
+  const bool emitWrappers;
 };
 
 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@@ -2273,14 +2622,17 @@
 }
 
 void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool emitCWrappers) {
+  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
+                                    emitCWrappers);
 }
 
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlloca) {
-  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
+    bool useAlloca, bool emitCWrappers) {
+  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
+                                                  emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
   populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
 }
@@ -2346,13 +2698,20 @@
   for (auto it : llvm::zip(opOperands, operands)) {
     auto operand = std::get<0>(it);
     auto llvmOperand = std::get<1>(it);
-    if (!operand.getType().isa<MemRefType>() &&
-        !operand.getType().isa<UnrankedMemRefType>()) {
-      promotedOperands.push_back(operand);
+
+    if (operand.getType().isa<UnrankedMemRefType>()) {
+      UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+                                       promotedOperands);
       continue;
     }
-    promotedOperands.push_back(
-        promoteOneMemRefDescriptor(loc, llvmOperand, builder));
+    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+      MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+                               promotedOperands);
+      continue;
+    }
+
+    promotedOperands.push_back(operand);
   }
   return promotedOperands;
 }
@@ -2362,11 +2721,21 @@
 struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   /// Creates an LLVM lowering pass.
  explicit LLVMLoweringPass(bool useAlloca = false,
-                            bool useBarePtrCallConv = false)
-      : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv) {}
+                            bool useBarePtrCallConv = false,
+                            bool emitCWrappers = false)
+      : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv),
+        emitCWrappers(emitCWrappers) {}
 
   /// Run the dialect converter on the module.
   void runOnModule() override {
+    if (useBarePtrCallConv && emitCWrappers) {
+      getModule().emitError()
+          << "incompatible conversion options: bare-pointer calling convention "
+             "and C wrapper emission";
+      signalPassFailure();
+      return;
+    }
+
     ModuleOp m = getModule();
 
     LLVM::ensureDistinctSuccessors(m);
@@ -2380,7 +2749,8 @@
       populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
                                                  useAlloca);
     else
-      populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca);
+      populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca,
+                                          emitCWrappers);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect>();
@@ -2393,19 +2763,23 @@
 
   /// Convert memrefs to bare pointers in function signatures.
   bool useBarePtrCallConv;
+
+  /// Emit wrappers for C-compatible pointer-to-struct memref descriptors.
+  bool emitCWrappers;
 };
 } // end namespace
 
 std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::createLowerToLLVMPass(bool useAlloca) {
-  return std::make_unique<LLVMLoweringPass>(useAlloca);
+mlir::createLowerToLLVMPass(bool useAlloca, bool emitCWrappers) {
+  return std::make_unique<LLVMLoweringPass>(useAlloca, emitCWrappers);
 }
 
 static PassRegistration<LLVMLoweringPass>
-    pass("convert-std-to-llvm",
+    pass(PASS_NAME,
          "Convert scalar and vector operations from the "
         "Standard to the LLVM dialect",
         [] {
           return std::make_unique<LLVMLoweringPass>(
               clUseAlloca.getValue(), clUseBarePtrCallConv.getValue(),
               clEmitCWrappers.getValue());
         });
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -90,26 +90,27 @@
       return launchOp.emitOpError("kernel function is missing the '")
              << GPUDialect::getKernelFuncAttrName() << "' attribute";
 
+    // TODO(ntv,zinenko,herhut): if the kernel function has been converted to
+    // the LLVM dialect but the caller hasn't (which happens during separate
+    // compilation), do not check type correspondence as it would require the
+    // verifier to be aware of the LLVM type conversion.
+    if (kernelLLVMFunction)
+      return success();
+
     unsigned actualNumArguments = launchOp.getNumKernelOperands();
-    unsigned expectedNumArguments = kernelLLVMFunction
-                                        ? kernelLLVMFunction.getNumArguments()
-                                        : kernelGPUFunction.getNumArguments();
+    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
     if (expectedNumArguments != actualNumArguments)
       return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;
 
-    // Due to the ordering of the current impl of lowering and LLVMLowering,
-    // type checks need to be temporarily disabled.
-    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
-    // to encode target module" has landed.
- // auto functionType = kernelFunc.getType(); - // for (unsigned i = 0; i < numKernelFuncArgs; ++i) { - // if (getKernelOperand(i).getType() != functionType.getInput(i)) { - // return emitOpError("type of function argument ") - // << i << " does not match"; - // } - // } + auto functionType = kernelGPUFunction.getType(); + for (unsigned i = 0; i < expectedNumArguments; ++i) { + if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) { + return launchOp.emitOpError("type of function argument ") + << i << " does not match"; + } + } return success(); }); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -401,8 +401,7 @@ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); Operation *cast = typeConverter->materializeConversion( rewriter, origArg.getType(), replArgs, loc); - assert(cast->getNumResults() == 1 && - cast->getNumOperands() == replArgs.size()); + assert(cast->getNumResults() == 1); mapping.map(origArg, cast->getResult(0)); info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); diff --git a/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir b/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir --- a/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir +++ b/mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir @@ -6,8 +6,8 @@ // CHECK: llvm.mlir.global internal constant @[[global:.*]]("CUBIN") gpu.module @kernel_module attributes {nvvm.cubin = "CUBIN"} { - gpu.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} { - gpu.return + llvm.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} { + llvm.return } } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir @@ -1,25 +1,11 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s - -// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> {dialect.a = true, dialect.b = 4 : i64}) { -// CHECK-NEXT: llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-LABEL: func @check_attributes +// When expanding the memref to multiple arguments, argument attributes are replicated. 
+// CHECK-COUNT-7: {dialect.a = true, dialect.b = 4 : i64}
 func @check_attributes(%static: memref<10x20xf32> {dialect.a = true, dialect.b = 4 : i64 }) {
   %c0 = constant 0 : index
   %0 = load %static[%c0, %c0]: memref<10x20xf32>
   return
 }

-// CHECK-LABEL: func @external_func(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
-// CHECK: func @call_external(%[[arg:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
-// CHECK:   %[[ld:.*]] = llvm.load %[[arg]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK:   %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK:   %[[alloca:.*]] = llvm.alloca %[[c1]] x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK:   llvm.store %[[ld]], %[[alloca]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK:   call @external_func(%[[alloca]]) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
-func @external_func(memref<10x20xf32>)
-
-func @call_external(%static: memref<10x20xf32>) {
-  call @external_func(%static) : (memref<10x20xf32>) -> ()
-  return
-}
-
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,14 +1,25 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s

-// CHECK-LABEL: func @check_strided_memref_arguments(
-// CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-LABEL: func @check_strided_memref_arguments(
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
 func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>,
                                      %dynamic : memref<?x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>,
                                      %mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) {
   return
 }

-// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+// CHECK-LABEL: func @check_arguments
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
 func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
   return
 }
@@ -16,7 +27,7 @@
 // CHECK-LABEL: func @mixed_alloc(
 // CHECK:         %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> {
 func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
-// CHECK-NEXT:  %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK:       %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 // CHECK-NEXT:  llvm.mul %[[M]], %[[c42]] : !llvm.i64
 // CHECK-NEXT:  %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
@@ -45,10 +56,9 @@
   return %0 : memref<?x42x?xf32>
 }

-// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) {
+// CHECK-LABEL: func @mixed_dealloc
 func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 // CHECK-NEXT:  %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
   dealloc %arg0 : memref<?x42x?xf32>
@@ -59,7 +69,7 @@
 // CHECK-LABEL: func @dynamic_alloc(
 // CHECK:         %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
 func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
-// CHECK-NEXT:  %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
+// CHECK:       %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
 // CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@@ -83,10 +93,9 @@
   return %0 : memref<?x?xf32>
 }

-// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @dynamic_dealloc
 func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
   dealloc %arg0 : memref<?x?xf32>
@@ -94,10 +103,12 @@
 }

 // CHECK-LABEL: func @mixed_load(
-// CHECK:         %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK:         %[[I:.*]]: !llvm.i64,
+// CHECK:         %[[J:.*]]: !llvm.i64)
 func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -112,10 +123,8 @@
 }

 // CHECK-LABEL: func @dynamic_load(
-// CHECK:         %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
 func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -131,8 +140,7 @@

 // CHECK-LABEL: func @prefetch
 func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -161,8 +169,7 @@

 // CHECK-LABEL: func @dynamic_store
 func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -171,15 +178,14 @@
 // CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
 // CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
   store %val, %dynamic[%i, %j] : memref<?x?xf32>
   return
 }

 // CHECK-LABEL: func @mixed_store
 func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -188,74 +194,66 @@
 // CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
 // CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
   store %val, %mixed[%i, %j] : memref<42x?xf32>
   return
 }

 // CHECK-LABEL: func @memref_cast_static_to_dynamic
 func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
   return
 }

 // CHECK-LABEL: func @memref_cast_static_to_mixed
 func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %static : memref<10x42xf32> to memref
   return
 }

 // CHECK-LABEL: func @memref_cast_dynamic_to_static
 func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
   return
 }

 // CHECK-LABEL: func @memref_cast_dynamic_to_mixed
 func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %dynamic : memref<?x?xf32> to memref
   return
 }

 // CHECK-LABEL: func @memref_cast_mixed_to_dynamic
 func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
   return
 }

 // CHECK-LABEL: func @memref_cast_mixed_to_static
 func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
   return
 }

 // CHECK-LABEL: func @memref_cast_mixed_to_mixed
 func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref
   return
 }

 // CHECK-LABEL: func @memref_cast_ranked_to_unranked
 func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
 // CHECK-DAG:  %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-DAG:  %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG:  llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG:  %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
+// CHECK-DAG:  llvm.store %{{.*}}, %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK-DAG:  %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
 // CHECK-DAG:  %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
 // CHECK:      llvm.mlir.undef : !llvm<"{ i64, i8* }">
 // CHECK-DAG:  llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
@@ -266,19 +264,17 @@

 // CHECK-LABEL: func @memref_cast_unranked_to_ranked
 func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
-// CHECK:      %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
-// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
+// CHECK:      %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i8* }">
 // CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
   %0 = memref_cast %arg : memref<*xf32> to memref<?x?x?x?xf32>
   return
 }

-// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @mixed_memref_dim
 func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK:       llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT:  llvm.extractvalue %[[ld]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld:.*]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
   %1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
 // CHECK-NEXT:  llvm.extractvalue %[[ld]][3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
   %2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
@@ -18,12 +18,12 @@
 //CHECK: llvm.func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
 func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))

-// Check that memrefs are converted to pointers-to-struct if appear as function arguments.
-// CHECK: llvm.func @memref_call_conv(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
+// Check that memrefs are converted to argument packs if they appear as function arguments.
+// CHECK: llvm.func @memref_call_conv(!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
 func @memref_call_conv(%arg0: memref<?xf32>)

 // Same in nested functions.
-// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void ({ float*, float*, i64, [1 x i64], [1 x i64] }*)*">)
+// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void (float*, float*, i64, i64, i64)*">)
 func @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())

 //CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -10,7 +10,10 @@

 // -----

-// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// CHECK-LABEL: func @check_static_return
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-LABEL: func @check_static_return
 // BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
 func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
@@ -76,11 +79,10 @@

 // -----

-// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
+// CHECK-LABEL: func @zero_d_dealloc
 // BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) {
 func @zero_d_dealloc(%arg0: memref<f32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
 // CHECK-NEXT:  %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@@ -96,7 +98,7 @@
 // CHECK-LABEL: func @aligned_1d_alloc(
 // BAREPTR-LABEL: func @aligned_1d_alloc(
 func @aligned_1d_alloc() -> memref<42xf32> {
-// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK:       llvm.mlir.constant(42 : index) : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
 // CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@@ -150,13 +152,13 @@
 // CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
 // BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
 func @static_alloc() -> memref<32x18xf32> {
-// CHECK-NEXT:  %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK:       %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
 // CHECK-NEXT:  %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
 // CHECK-NEXT:  %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
 // CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
 // CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
 // CHECK-NEXT:  %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
 // CHECK-NEXT:  %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
 // CHECK-NEXT:  llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
@@ -177,11 +179,10 @@

 // -----

-// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @static_dealloc
 // BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) {
 func @static_dealloc(%static: memref<10x8xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK-NEXT:  %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@@ -194,11 +195,10 @@

 // -----

-// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
+// CHECK-LABEL: func @zero_d_load
 // BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float
 func @zero_d_load(%arg0: memref<f32>) -> f32 {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
 // CHECK-NEXT:  %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
 // CHECK-NEXT:  %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
@@ -214,20 +214,22 @@

 // -----

 // CHECK-LABEL: func @static_load(
-// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK: %[[I:.*]]: !llvm.i64,
+// CHECK: %[[J:.*]]: !llvm.i64)
 // BAREPTR-LABEL: func @static_load
 // BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
 func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
 // CHECK-NEXT:  llvm.load %[[addr]] : !llvm<"float*">

 // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -246,15 +248,14 @@

 // -----

-// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
+// CHECK-LABEL: func @zero_d_store
 // BAREPTR-LABEL: func @zero_d_store
 // BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float)
 func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64 }">
 // CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg1, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">

 // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
 // BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@@ -270,17 +271,16 @@
 // BAREPTR-LABEL: func @static_store
 // BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">
 func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK:       %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">

 // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@@ -298,11 +298,10 @@

 // -----

-// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @static_memref_dim
 // BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
 func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
-// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK:       llvm.mlir.constant(42 : index) : !llvm.i64
 // BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
 // BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -728,9 +728,15 @@
 }

 // CHECK-LABEL: func @subview(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64,
+// CHECK: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG2:.*]]: !llvm.i64)
 func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -754,9 +760,10 @@
 }

 // CHECK-LABEL: func @subview_const_size(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
 func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -782,9 +789,10 @@
 }

 // CHECK-LABEL: func @subview_const_stride(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
 func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -1,8 +1,7 @@
 // RUN: mlir-opt %s -convert-std-to-llvm -split-input-file -verify-diagnostics | FileCheck %s

 // CHECK-LABEL: func @address_space(
-// CHECK: %{{.*}}: !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">)
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">
+// CHECK-SAME: !llvm<"float addrspace(7)*">
 func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
   %0 = alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
   %1 = constant 7 : index
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -175,23 +175,21 @@

 // -----

-gpu.module @kernels {
-  gpu.func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
-    gpu.return
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @kernel_1(%arg1 : f32) attributes { gpu.kernel } {
+      gpu.return
+    }
   }
-}
-
-// Due to the ordering of the current impl of lowering and LLVMLowering, type
-// checks need to be temporarily disabled.
-// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
-// to encode target module" has landed.
-// func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
-//   // expected-err@+1 {{type of function argument 0 does not match}}
-//   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
-//       {kernel = "kernel_1"}
-//       : (index, index, index, index, index, index, f32) -> ()
-//   return
-// }
+  func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
+    // expected-error@+1 {{type of function argument 0 does not match}}
+    "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
+        {kernel = "kernel_1", kernel_module = @kernels}
+        : (index, index, index, index, index, index, f32) -> ()
+    return
+  }
+}

 // -----

diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -52,9 +52,11 @@
   linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref
   return
 }
-// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
-// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64 }*">) -> ()
+// CHECK-LABEL: func @dot
+// CHECK: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}) :
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64

 func @slice_with_range_and_index(%arg0: memref) {
   %c0 = constant 0 : index
@@ -83,7 +85,9 @@
   return
 }
 // CHECK-LABEL: func @copy
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32({{.*}}) :
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64

 func @transpose(%arg0: memref) {
   %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref
@@ -128,9 +132,8 @@
 // CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 // CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// Call external copy after promoting input and output structs to pointers
-// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// Call external copy.
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32

 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
   affine_map<(m, n, k) -> (m, n)>
 ]
@@ -163,7 +166,10 @@
   return
 }
 // CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) :
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64

 // LLVM-LOOPS-LABEL: func @matmul_vec_impl(
 // LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
@@ -195,7 +201,10 @@
 }
 // CHECK-LABEL: func @matmul_vec_indexed(
 // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}) :
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64

 func @reshape_static(%arg0: memref<3x4x5xf32>) {
   // Reshapes that expand and collapse back a contiguous tensor with some 1's.
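
The Linalg tests above now expect external calls that pass each memref as a list of scalar arguments, while the test libraries in the following files implement `_mlir_ciface_`-prefixed functions that take a pointer to a complete descriptor. A minimal C++ sketch of how the two ABIs relate; the names `example` and `_mlir_ciface_example` are hypothetical, and `StridedMemRefType` is reduced to a stand-in for the template defined in mlir_runner_utils.h:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for the StridedMemRefType template used by the test libraries.
template <typename T, int N> struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// C-interface style callee: receives a pointer to a complete 1-D descriptor,
// like the _mlir_ciface_* functions in the files below.
extern "C" void _mlir_ciface_example(StridedMemRefType<float, 1> *X) {
  std::printf("size[0] = %lld\n", static_cast<long long>(X->sizes[0]));
}

// Expanded-argument entry point matching the new calling convention: repack
// the five scalar arguments of a memref<?xf32> into a descriptor and forward
// a pointer to it.
extern "C" void example(float *allocated, float *aligned, int64_t offset,
                        int64_t size, int64_t stride) {
  StridedMemRefType<float, 1> descriptor;
  descriptor.basePtr = allocated;
  descriptor.data = aligned;
  descriptor.offset = offset;
  descriptor.sizes[0] = size;
  descriptor.strides[0] = stride;
  _mlir_ciface_example(&descriptor);
}
```
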
diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
--- a/mlir/test/mlir-cpu-runner/cblas_interface.cpp
+++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
@@ -15,32 +15,35 @@
 #include
 #include

-extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
-                                        float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
   X->data[X->offset] = f;
 }

-extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
-                                          float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
+                                       float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     *(X->data + X->offset + i * X->strides[0]) = f;
 }

-extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
-                                            float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+                                         float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     for (unsigned j = 0; j < X->sizes[1]; ++j)
       *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
 }

-extern "C" void linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                                            StridedMemRefType<float, 0> *O) {
+extern "C" void
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+                                         StridedMemRefType<float, 0> *O) {
   O->data[O->offset] = I->data[I->offset];
 }

 extern "C" void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                StridedMemRefType<float, 1> *O) {
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+                                             StridedMemRefType<float, 1> *O) {
   if (I->sizes[0] != O->sizes[0]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *I);
@@ -52,9 +55,8 @@
         I->data[I->offset + i * I->strides[0]];
 }

-extern "C" void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
-                                    StridedMemRefType<float, 2> *O) {
+extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
   if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *I);
@@ -69,10 +71,9 @@
         I->data[I->offset + i * si0 + j * si1];
 }

-extern "C" void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
-                                       StridedMemRefType<float, 1> *Y,
-                                       StridedMemRefType<float, 0> *Z) {
+extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+    StridedMemRefType<float, 0> *Z) {
   if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *X);
@@ -85,7 +86,7 @@
                         Y->data + Y->offset, Y->strides[0]);
 }

-extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
     StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
     StridedMemRefType<float, 2> *C) {
   if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
diff --git a/mlir/test/mlir-cpu-runner/include/cblas_interface.h b/mlir/test/mlir-cpu-runner/include/cblas_interface.h
--- a/mlir/test/mlir-cpu-runner/include/cblas_interface.h
+++ b/mlir/test/mlir-cpu-runner/include/cblas_interface.h
@@ -25,33 +25,34 @@
 #endif // _WIN32

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X, float f);
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+                                         float f);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                            StridedMemRefType<float, 0> *O);
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+                                         StridedMemRefType<float, 0> *O);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                StridedMemRefType<float, 1> *O);
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+                                             StridedMemRefType<float, 1> *O);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
-                                    StridedMemRefType<float, 2> *O);
+_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
-                                       StridedMemRefType<float, 1> *Y,
-                                       StridedMemRefType<float, 0> *Z);
+_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+    StridedMemRefType<float, 0> *Z);

 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
     StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
     StridedMemRefType<float, 2> *C);
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -261,23 +261,27 @@
 // Currently exposed C API.
 ////////////////////////////////////////////////////////////////////////////////
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_i8(UnrankedMemRefType<int8_t> *M);
+_mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_f32(UnrankedMemRefType<float> *M);
+_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(int64_t rank,
+                                                          void *ptr);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_0d_f32(StridedMemRefType<float, 0> *M);
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_1d_f32(StridedMemRefType<float, 1> *M);
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_2d_f32(StridedMemRefType<float, 2> *M);
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_3d_f32(StridedMemRefType<float, 3> *M);
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_4d_f32(StridedMemRefType<float, 4> *M);
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);

 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
+_mlir_ciface_print_memref_vector_4x4xf32(
+    StridedMemRefType<Vector2D<4, 4, float>, 2> *M);

 // Small runtime support "lib" for vector.print lowering.
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f); diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp --- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp +++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp @@ -16,8 +16,8 @@ #include #include -extern "C" void -print_memref_vector_4x4xf32(StridedMemRefType, 2> *M) { +extern "C" void _mlir_ciface_print_memref_vector_4x4xf32( + StridedMemRefType, 2> *M) { impl::printMemRef(*M); } @@ -26,7 +26,7 @@ impl::printMemRef(*(static_cast *>(ptr))); \ break -extern "C" void print_memref_i8(UnrankedMemRefType *M) { +extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType *M) { printUnrankedMemRefMetaData(std::cout, *M); int rank = M->rank; void *ptr = M->descriptor; @@ -42,7 +42,7 @@ } } -extern "C" void print_memref_f32(UnrankedMemRefType *M) { +extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType *M) { printUnrankedMemRefMetaData(std::cout, *M); int rank = M->rank; void *ptr = M->descriptor; @@ -58,19 +58,31 @@ } } -extern "C" void print_memref_0d_f32(StridedMemRefType *M) { +extern "C" void print_memref_f32(int64_t rank, void *ptr) { + UnrankedMemRefType descriptor; + descriptor.rank = rank; + descriptor.descriptor = ptr; + _mlir_ciface_print_memref_f32(&descriptor); +} + +extern "C" void +_mlir_ciface_print_memref_0d_f32(StridedMemRefType *M) { impl::printMemRef(*M); } -extern "C" void print_memref_1d_f32(StridedMemRefType *M) { +extern "C" void +_mlir_ciface_print_memref_1d_f32(StridedMemRefType *M) { impl::printMemRef(*M); } -extern "C" void print_memref_2d_f32(StridedMemRefType *M) { +extern "C" void +_mlir_ciface_print_memref_2d_f32(StridedMemRefType *M) { impl::printMemRef(*M); } -extern "C" void print_memref_3d_f32(StridedMemRefType *M) { +extern "C" void +_mlir_ciface_print_memref_3d_f32(StridedMemRefType *M) { impl::printMemRef(*M); } -extern "C" void print_memref_4d_f32(StridedMemRefType *M) { +extern "C" void +_mlir_ciface_print_memref_4d_f32(StridedMemRefType *M) { impl::printMemRef(*M); } diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir --- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir +++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir @@ -17,12 +17,13 @@ %21 = constant 5 : i32 %22 = memref_cast %arg0 : memref<5xf32> to memref call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref) -> () - call @print_memref_1d_f32(%22) : (memref) -> () + %23 = memref_cast %22 : memref to memref<*xf32> + call @print_memref_f32(%23) : (memref<*xf32>) -> () %24 = constant 1.0 : f32 call @other_func(%24, %22) : (f32, memref) -> () - call @print_memref_1d_f32(%22) : (memref) -> () + call @print_memref_f32(%23) : (memref<*xf32>) -> () return } func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref) -func @print_memref_1d_f32(memref) +func @print_memref_f32(%ptr : memref<*xf32>) diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -96,11 +96,34 @@ std::fill_n(arg->data, count, value); mcuMemHostRegister(arg->data, count * sizeof(T)); } -extern "C" void -mcuMemHostRegisterMemRef1dFloat(const MemRefType *arg) { - mcuMemHostRegisterMemRef(arg, 1.23f); + +extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated, + float *aligned, int64_t offset, + int64_t size, int64_t stride) { + MemRefType descriptor; + descriptor.basePtr 
= allocated; + descriptor.data = aligned; + descriptor.offset = offset; + descriptor.sizes[0] = size; + descriptor.strides[0] = stride; + mcuMemHostRegisterMemRef(&descriptor, 1.23f); } -extern "C" void -mcuMemHostRegisterMemRef3dFloat(const MemRefType *arg) { - mcuMemHostRegisterMemRef(arg, 1.23f); + +extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated, + float *aligned, int64_t offset, + int64_t size0, int64_t size1, + int64_t size2, int64_t stride0, + int64_t stride1, + int64_t stride2) { + MemRefType descriptor; + descriptor.basePtr = allocated; + descriptor.data = aligned; + descriptor.offset = offset; + descriptor.sizes[0] = size0; + descriptor.strides[0] = stride0; + descriptor.sizes[1] = size1; + descriptor.strides[1] = stride1; + descriptor.sizes[2] = size2; + descriptor.strides[2] = stride2; + mcuMemHostRegisterMemRef(&descriptor, 1.23f); }
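
The two registration wrappers above repack the expanded arguments by hand for ranks 1 and 3. A rank-generic helper along the following lines could factor out that packing; `packDescriptor` is hypothetical and only assumes the `MemRefType` layout used in cuda-runtime-wrappers.cpp:

```cpp
#include <cstdint>

// Stand-in for the MemRefType template from cuda-runtime-wrappers.cpp.
template <typename T, int N> struct MemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// Hypothetical helper: rebuild an N-D descriptor from the expanded scalar
// arguments produced by the new memref calling convention.
template <typename T, int N>
MemRefType<T, N> packDescriptor(T *allocated, T *aligned, int64_t offset,
                                const int64_t (&sizes)[N],
                                const int64_t (&strides)[N]) {
  MemRefType<T, N> descriptor;
  descriptor.basePtr = allocated;
  descriptor.data = aligned;
  descriptor.offset = offset;
  for (int i = 0; i < N; ++i) {
    descriptor.sizes[i] = sizes[i];
    descriptor.strides[i] = strides[i];
  }
  return descriptor;
}

// Usage mirroring mcuMemHostRegisterMemRef3dFloat:
//   int64_t sizes[3] = {size0, size1, size2};
//   int64_t strides[3] = {stride0, stride1, stride2};
//   MemRefType<float, 3> desc =
//       packDescriptor(allocated, aligned, offset, sizes, strides);
```
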