diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md --- a/mlir/docs/ConversionToLLVMDialect.md +++ b/mlir/docs/ConversionToLLVMDialect.md @@ -246,7 +246,7 @@ } ``` -### Calling Convention for `memref` +### Calling Convention for Ranked `memref` Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a list of arguments of non-aggregate types that the memref descriptor defined @@ -317,7 +317,9 @@ ``` -For **unranked** memrefs, the list of function arguments always contains two +### Calling Convention for Unranked `memref` + +For unranked memrefs, the list of function arguments always contains two elements, same as the unranked memref descriptor: an integer rank, and a type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that while the _calling convention_ does not require stack allocation, _casting_ to @@ -369,6 +371,20 @@ } ``` +**Lifetime.** The second element of the unranked memref descriptor points to +some memory in which the ranked memref descriptor is stored. By convention, this +memory is allocated on stack and has the lifetime of the function. (*Note:* due +to function-length lifetime, creation of multiple unranked memref descriptors, +e.g., in a loop, may lead to stack overflows.) If an unranked descriptor has to +be returned from a function, the ranked descriptor it points to is copied into +dynamically allocated memory, and the pointer in the unranked descriptor is +updated accordingly. The allocation happens immediately before returning. It is +the responsibility of the caller to free the dynamically allocated memory. The +default conversion of `std.call` and `std.call_indirect` copies the ranked +descriptor to newly allocated memory on the caller's stack. Thus, the convention +of the ranked memref descriptor pointed to by an unranked memref descriptor +being stored on stack is respected. 
+ *This convention may or may not apply if the conversion of MemRef types is overridden by the user.* diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -129,6 +129,9 @@ /// Gets the bitwidth of the index type when converted to LLVM. unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; } + /// Gets the pointer bitwidth. + unsigned getPointerBitwidth(unsigned addressSpace = 0); + protected: /// LLVM IR module used to parse/create types. llvm::Module *module; @@ -386,6 +389,13 @@ /// Returns the number of non-aggregate values that would be produced by /// `unpack`. static unsigned getNumUnpackedValues() { return 2; } + + /// Builds IR computing the sizes in bytes (suitable for opaque allocation) + /// and appends the corresponding values into `sizes`. + static void computeSizes(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + ArrayRef values, + SmallVectorImpl &sizes); }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -794,6 +794,13 @@ def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">; def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">; +def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]>, + Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src, + LLVM_Type:$len, LLVM_Type:$isVolatile)>; +def LLVM_MemcpyInlineOp : LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2]>, + Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src, + LLVM_Type:$len, LLVM_Type:$isVolatile)>; + // // Vector Reductions. 
// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -184,6 +185,10 @@ return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth()); } +unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { + return module->getDataLayout().getPointerSizeInBits(addressSpace); +} + Type LLVMTypeConverter::convertIndexType(IndexType type) { return getIndexType(); } @@ -769,6 +774,51 @@ results.push_back(d.memRefDescPtr(builder, loc)); } +void UnrankedMemRefDescriptor::computeSizes( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + ArrayRef values, SmallVectorImpl &sizes) { + if (values.empty()) + return; + + // Cache the index type. + LLVM::LLVMType indexType = typeConverter.getIndexType(); + + // Initialize shared constants. + Value one = createIndexAttrConstant(builder, loc, indexType, 1); + Value two = createIndexAttrConstant(builder, loc, indexType, 2); + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); + Value indexSize = + createIndexAttrConstant(builder, loc, indexType, + ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); + + sizes.reserve(sizes.size() + values.size()); + for (UnrankedMemRefDescriptor desc : values) { + // Emit IR computing the memory necessary to store the descriptor. This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). 
+ // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value doublePointerSize = + builder.create(loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = builder.create(loc, indexType, two, rank); + Value doubleRankIncremented = + builder.create(loc, indexType, doubleRank, one); + Value rankIndexSize = builder.create( + loc, indexType, doubleRankIncremented, indexSize); + + // Total allocation size. + Value allocationSize = builder.create( + loc, indexType, doublePointerSize, rankIndexSize); + sizes.push_back(allocationSize); + } +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } @@ -1863,6 +1913,104 @@ using AllocaOpLowering = AllocLikeOpLowering; +/// Copies the shaped descriptor part to (if `toDynamic` is set) or from +/// (otherwise) the dynamically allocated memory for any operands that were +/// unranked descriptors originally. +static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + TypeRange origTypes, + SmallVectorImpl &operands, + bool toDynamic) { + assert(origTypes.size() == operands.size() && + "expected as may original types as operands"); + + // Find operands of unranked memref type and store them. + SmallVector unrankedMemrefs; + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (!origTypes[i].isa()) + continue; + unrankedMemrefs.emplace_back(operands[i]); + } + + if (unrankedMemrefs.empty()) + return success(); + + // Compute allocation sizes. + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter, + unrankedMemrefs, sizes); + + // Get frequently used types. 
+ auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect()); + auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()); + auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()); + LLVM::LLVMType indexType = typeConverter.getIndexType(); + + // Find the malloc and free, or declare them if necessary. + auto module = builder.getInsertionPoint()->getParentOfType(); + auto mallocFunc = module.lookupSymbol("malloc"); + if (!mallocFunc && toDynamic) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + mallocFunc = builder.create( + builder.getUnknownLoc(), "malloc", + LLVM::LLVMType::getFunctionTy( + voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false)); + } + auto freeFunc = module.lookupSymbol("free"); + if (!freeFunc && !toDynamic) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + freeFunc = builder.create( + builder.getUnknownLoc(), "free", + LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType), + /*isVarArg=*/false)); + } + + // Initialize shared constants. + Value zero = + builder.create(loc, i1Type, builder.getBoolAttr(false)); + + unsigned unrankedMemrefPos = 0; + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + Type type = origTypes[i]; + if (!type.isa()) + continue; + Value allocationSize = sizes[unrankedMemrefPos++]; + UnrankedMemRefDescriptor desc(operands[i]); + + // Allocate memory, copy, and free the source if necessary. + Value memory = + toDynamic + ? builder.create(loc, mallocFunc, allocationSize) + .getResult(0) + : builder.create(loc, voidPtrType, allocationSize, + /*alignment=*/0); + + Value source = desc.memRefDescPtr(builder, loc); + builder.create(loc, memory, source, allocationSize, zero); + if (!toDynamic) + builder.create(loc, freeFunc, source); + + // Create a new descriptor. 
The same descriptor can be returned multiple + // times, attempting to modify its pointer can lead to memory leaks + // (allocated twice and overwritten) or double frees (the caller does not + // know if the descriptor points to the same memory). + Type descriptorType = typeConverter.convertType(type); + if (!descriptorType) + return failure(); + auto updatedDesc = + UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); + Value rank = desc.rank(builder, loc); + updatedDesc.setRank(builder, loc, rank); + updatedDesc.setMemRefDescPtr(builder, loc, memory); + + operands[i] = updatedDesc; + } + + return success(); +} + // A CallOp automatically promotes MemRefType to a sequence of alloca/store and // passes the pointer to the MemRef across function boundaries. template @@ -1882,13 +2030,6 @@ unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - for (Type resType : resultTypes) { - assert(!resType.isa() && - "Returning unranked memref is not supported. Pass result as an" - "argument instead."); - (void)resType; - } - if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) @@ -1900,25 +2041,25 @@ auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); - // If < 2 results, packing did not do anything and we can just return. - if (numResults < 2) { - rewriter.replaceOp(op, newOp.getResults()); - return success(); - } - - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around - // a particular interaction between MemRefType and CallOp lowering. Find a - // way to avoid special casing. 
SmallVector results; - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - auto type = this->typeConverter.convertType(op->getResult(i).getType()); - results.push_back(rewriter.create( - op->getLoc(), type, newOp.getOperation()->getResult(0), - rewriter.getI64ArrayAttr(i))); + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newOp.result_begin(), newOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = this->typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } } + if (failed(copyUnrankedDescriptors( + rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(), + results, /*toDynamic=*/false))) + return failure(); rewriter.replaceOp(op, results); return success(); @@ -2397,6 +2538,12 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); + auto updatedOperands = llvm::to_vector<4>(operands); + // Do not ignore a failure to copy the unranked descriptors. + if (failed(copyUnrankedDescriptors( + rewriter, op->getLoc(), typeConverter, op->getOperands().getTypes(), + updatedOperands, /*toDynamic=*/true))) + return failure(); // If ReturnOp has 0 or 1 operand, create it and return immediately. 
if (numArguments == 0) { @@ -2406,7 +2553,7 @@ } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, ArrayRef(), operands.front(), op->getAttrs()); + op, ArrayRef(), updatedOperands, op->getAttrs()); return success(); } @@ -2418,7 +2565,7 @@ Value packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( - op->getLoc(), packedType, packed, operands[i], + op->getLoc(), packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir --- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir @@ -109,3 +109,134 @@ // EMIT_C_ATTRIBUTE: @_mlir_ciface_other_callee // EMIT_C_ATTRIBUTE: llvm.call @other_callee + +//===========================================================================// +// Calling convention on returning unranked memrefs. +//===========================================================================// + +// CHECK-LABEL: llvm.func @return_var_memref_caller +func @return_var_memref_caller(%arg0: memref<4x3xf32>) { + // CHECK: %[[CALL_RES:.*]] = llvm.call @return_var_memref + %0 = call @return_var_memref(%arg0) : (memref<4x3xf32>) -> memref<*xf32> + + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index) + // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index) + // These sizes may depend on the data layout, not matching specific values. 
+ // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant + // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant + + // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] + // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] + // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] + // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] + // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOC_SIZE]] x !llvm.i8 + // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[CALL_RES]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]]) + // CHECK: llvm.call @free(%[[SOURCE]]) + // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[DESC]][0] + // CHECK: llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1] + return +} + +// CHECK-LABEL: llvm.func @return_var_memref +func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> { + // Match the construction of the unranked descriptor. + // CHECK: %[[ALLOCA:.*]] = llvm.alloca + // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] + // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0] + // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1] + %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32> + + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index) + // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index) + // These sizes may depend on the data layout, not matching specific values. 
+ // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant + // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant + + // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] + // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] + // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] + // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] + // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) + // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]]) + // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]]) + // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0] + // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1] + // CHECK: llvm.return %[[NEW_DESC_2]] + return %0 : memref<*xf32> +} + +// CHECK-LABEL: llvm.func @return_two_var_memref_caller +func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) { + // Only check that we create two different descriptors using different + // memory, and deallocate both sources. The size computation is the same as + // for the single result. 
+ // CHECK: %[[CALL_RES:.*]] = llvm.call @return_two_var_memref + // CHECK: %[[RES_1:.*]] = llvm.extractvalue %[[CALL_RES]][0] + // CHECK: %[[RES_2:.*]] = llvm.extractvalue %[[CALL_RES]][1] + %0:2 = call @return_two_var_memref(%arg0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) + + // CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %{{.*}} x !llvm.i8 + // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[RES_1:.*]][1] : ![[DESC_TYPE:.*]] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]]) + // CHECK: llvm.call @free(%[[SOURCE_1]]) + // CHECK: %[[DESC_1:.*]] = llvm.mlir.undef : ![[DESC_TYPE]] + // CHECK: %[[DESC_11:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_1]][0] + // CHECK: llvm.insertvalue %[[ALLOCA_1]], %[[DESC_11]][1] + + // CHECK: %[[ALLOCA_2:.*]] = llvm.alloca %{{.*}} x !llvm.i8 + // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[RES_2:.*]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]]) + // CHECK: llvm.call @free(%[[SOURCE_2]]) + // CHECK: %[[DESC_2:.*]] = llvm.mlir.undef : ![[DESC_TYPE]] + // CHECK: %[[DESC_21:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_2]][0] + // CHECK: llvm.insertvalue %[[ALLOCA_2]], %[[DESC_21]][1] + return +} + +// CHECK-LABEL: llvm.func @return_two_var_memref +func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) { + // Match the construction of the unranked descriptor. + // CHECK: %[[ALLOCA:.*]] = llvm.alloca + // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] + // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0] + // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1] + %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32> + + // Only check that we allocate the memory for each operand of the "return" + // separately, even if both operands are the same value. 
The calling + // convention requires the caller to free them and the caller cannot know + // whether they are the same value or not. + // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}}) + // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]]) + // CHECK: %[[RES_1:.*]] = llvm.mlir.undef + // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0] + // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1] + + // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}}) + // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]]) + // CHECK: %[[RES_2:.*]] = llvm.mlir.undef + // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0] + // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1] + + // CHECK: %[[RESULTS:.*]] = llvm.mlir.undef : !llvm<"{ { i64, i8* }, { i64, i8* } }"> + // CHECK: %[[RESULTS_1:.*]] = llvm.insertvalue %[[RES_12]], %[[RESULTS]] + // CHECK: %[[RESULTS_2:.*]] = llvm.insertvalue %[[RES_22]], %[[RESULTS_1]] + // CHECK: llvm.return %[[RESULTS_2]] + return %0, %0 : memref<*xf32>, memref<*xf32> +} + diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -1,7 +1,9 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// CHECK-LABEL: func @ops(%arg0: !llvm.i32, %arg1: !llvm.float) -func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) { +// CHECK-LABEL: func @ops +func @ops(%arg0: !llvm.i32, %arg1: !llvm.float, + %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">, + %arg4: !llvm.i1) { // Integer arithmetic binary operations. 
// // CHECK-NEXT: %0 = llvm.add %arg0, %arg0 : !llvm.i32 @@ -109,6 +111,17 @@ // CHECK: "llvm.intr.ctpop"(%{{.*}}) : (!llvm.i32) -> !llvm.i32 %33 = "llvm.intr.ctpop"(%arg0) : (!llvm.i32) -> !llvm.i32 +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + +// CHECK: %[[SZ:.*]] = llvm.mlir.constant + %sz = llvm.mlir.constant(10: i64) : !llvm.i64 +// CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + // CHECK: llvm.return llvm.return } @@ -315,4 +328,4 @@ // CHECK: release llvm.fence release return -} \ No newline at end of file +} diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -202,6 +202,17 @@ llvm.return } +// CHECK-LABEL: @memcpy_test +llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) { + // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}}) + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + %sz = llvm.mlir.constant(10: i64) : !llvm.i64 + // CHECK: call void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 10, i1 %{{.*}}) + "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + llvm.return +} + + // Check that 
intrinsics are declared with appropriate types. // CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -231,3 +242,5 @@ // CHECK-DAG: declare void @llvm.matrix.column.major.store.v48f32.p0f32(<48 x float>, float* nocapture writeonly, i64, i1 immarg, i32 immarg, i32 immarg) // CHECK-DAG: declare <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>*, i32 immarg, <7 x i1>, <7 x float>) // CHECK-DAG: declare void @llvm.masked.store.v7f32.p0v7f32(<7 x float>, <7 x float>*, i32 immarg, <7 x i1>) +// CHECK-DAG: declare void @llvm.memcpy.p0i8.p0i8.i32(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i32, i1 immarg) +// CHECK-DAG: declare void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64 immarg, i1 immarg) diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir --- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -18,6 +18,21 @@ // CHECK: rank = 0 // 122 is ASCII for 'z'. 
// CHECK: [z] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] func @main() -> () { %A = alloc() : memref<10x3xf32, 0> %f2 = constant 2.00000e+00 : f32 @@ -48,8 +63,40 @@ call @print_memref_i8(%U4) : (memref<*xi8>) -> () dealloc %A : memref<10x3xf32, 0> + + call @return_var_memref_caller() : () -> () + call @return_two_var_memref_caller() : () -> () return } func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface } func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } + +func @return_two_var_memref_caller() { + %0 = alloca() : memref<4x3xf32> + %c0f32 = constant 1.0 : f32 + linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32 + %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) + call @print_memref_f32(%1#0) : (memref<*xf32>) -> () + call @print_memref_f32(%1#1) : (memref<*xf32>) -> () + return + } + + func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) { + %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32> + return %0, %0 : memref<*xf32>, memref<*xf32> +} + +func @return_var_memref_caller() { + %0 = alloca() : memref<4x3xf32> + %c0f32 = constant 1.0 : f32 + linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32 + %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32> + call @print_memref_f32(%1) : (memref<*xf32>) -> () + return +} + +func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> { + %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32> + return %0 : memref<*xf32> +}