diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -77,10 +77,17 @@ OpBuilder &builder); protected: + /// Convert a function argument type to an LLVM type using 'convertType'. + /// MemRef arguments are promoted to a pointer to the converted type. + virtual LLVM::LLVMType convertArgType(Type type); + /// LLVM IR module used to parse/create types. llvm::Module *module; LLVM::LLVMDialect *llvmDialect; + // Extract an LLVM IR dialect type. + LLVM::LLVMType unwrap(Type type); + private: Type convertStandardType(Type type); @@ -120,9 +127,23 @@ // Get the LLVM representation of the index type based on the bitwidth of the // pointer as defined by the data layout of the module. LLVM::LLVMType getIndexType(); +}; - // Extract an LLVM IR dialect type. - LLVM::LLVMType unwrap(Type type); +/// Custom LLVMTypeConverter that overrides `convertFunctionSignature` to +/// replace the type of MemRef function arguments with a bare pointer to the +/// MemRef element type. +class BarePtrTypeConverter : public mlir::LLVMTypeConverter { +public: + using LLVMTypeConverter::LLVMTypeConverter; + +private: + /// Convert a function argument type to an LLVM type using 'convertType' + /// except for MemRef arguments. MemRef types are converted to LLVM bare + /// pointers to the MemRef element type. + LLVM::LLVMType convertArgType(Type type) override; + + /// Converts MemRef type to an LLVM bare pointer to the MemRef element type. 
+ mlir::Type convertMemRefTypeToBarePtr(mlir::MemRefType type); }; /// Helper class to produce LLVM dialect operations extracting or inserting diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -44,8 +44,8 @@ std::function(MLIRContext *)>; /// Collect a set of patterns to convert memory-related operations from the -/// Standard dialect to the LLVM dialect, excluding the memory-related -/// operations. +/// Standard dialect to the LLVM dialect, excluding non-memory-related +/// operations and FuncOp. void populateStdToLLVMMemoryConversionPatters( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); @@ -54,10 +54,26 @@ void populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Collect a set of patterns to convert from the Standard dialect to LLVM. +/// Collect the default pattern to convert a FuncOp to the LLVM dialect. +void populateStdToLLVMDefaultFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Collect a set of default patterns to convert from the Standard dialect to +/// LLVM. void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); +/// Collect the pattern to convert a FuncOp to the LLVM dialect using the bare +/// pointer calling convention for MemRef function arguments. +void populateStdToLLVMBarePtrFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Collect a set of patterns to convert from the Standard dialect to +/// LLVM using the bare pointer calling convention for MemRef function +/// arguments. 
+void populateStdToLLVMBarePtrConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// By default stdlib malloc/free are used for allocating MemRef payloads. /// Specifying `useAlloca-true` emits stack allocations instead. In the future diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -321,7 +321,7 @@ TypeConverter::SignatureConversion &conversion); /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument from, Value to); + void replaceUsesOfWith(Value from, Value to); /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -667,7 +667,7 @@ BlockArgument arg = block.getArgument(en.index()); Value loaded = rewriter.create(loc, arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); + rewriter.replaceUsesOfWith(arg, loaded); } } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -44,6 +44,12 @@ llvm::cl::desc("Replace emission of malloc/free by alloca"), llvm::cl::init(false)); +static llvm::cl::opt clUseBarePtrCallConv( + PASS_NAME "-use-bare-ptr-memref-call-conv", + llvm::cl::desc("Replace FuncOp's MemRef arguments with " + "bare pointers to the MemRef element types"), + llvm::cl::init(false)); + 
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : llvmDialect(ctx->getRegisteredDialect()) { assert(llvmDialect && "LLVM IR dialect is not registered"); @@ -107,6 +113,17 @@ return converted.getPointerTo(); } +// Convert a function argument type to an LLVM type using 'convertType'. MemRef +// arguments are promoted to a pointer to the converted type. +LLVM::LLVMType LLVMTypeConverter::convertArgType(Type type) { + auto converted = convertType(type).dyn_cast_or_null(); + if (!converted) + return {}; + if (type.isa() || type.isa()) + converted = converted.getPointerTo(); + return converted; +} + // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, @@ -117,11 +134,9 @@ // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(type.getInputs())) { Type type = en.value(); - auto converted = convertType(type).dyn_cast_or_null(); + auto converted = convertArgType(type).dyn_cast_or_null(); if (!converted) return {}; - if (type.isa() || type.isa()) - converted = converted.getPointerTo(); result.addInputs(en.index(), converted); } @@ -239,6 +254,33 @@ .Default([](Type) { return Type(); }); } +// Convert a function argument type to an LLVM type using 'convertType' except +// for MemRef arguments. MemRef types are converted to LLVM bare pointers to the +// MemRef element type. +LLVM::LLVMType BarePtrTypeConverter::convertArgType(Type type) { + // TODO: Add support for unranked memref. + if (auto memrefTy = type.dyn_cast()) + return convertMemRefTypeToBarePtr(memrefTy) + .dyn_cast_or_null(); + return convertType(type).dyn_cast_or_null(); +} + +// Converts MemRef type to an LLVM bare pointer to the MemRef element type. 
+Type BarePtrTypeConverter::convertMemRefTypeToBarePtr(MemRefType type) { + int64_t offset; + SmallVector strides; + bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); + assert(strideSuccess && + "Non-strided layout maps must have been normalized away"); + (void)strideSuccess; + + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); + return ptrTy; +} + LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &lowering_, PatternBenefit benefit) @@ -494,27 +536,29 @@ LLVM::LLVMDialect &dialect; }; -struct FuncOpConversion : public LLVMLegalizationPattern { - using LLVMLegalizationPattern::LLVMLegalizationPattern; - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); - FunctionType type = funcOp.getType(); - - // Store the positions of memref-typed arguments so that we can emit loads - // from them to follow the calling convention. - SmallVector promotedArgIndices; - promotedArgIndices.reserve(type.getNumInputs()); +struct FuncOpConversionBase : public LLVMLegalizationPattern { +protected: + using LLVMLegalizationPattern::LLVMLegalizationPattern; + using UnsignedTypePair = std::pair; + + // Gather the positions and types of memref-typed arguments in a given + // FunctionType. + void getMemRefArgIndicesAndTypes( + FunctionType type, SmallVectorImpl &argsInfo) const { + argsInfo.reserve(type.getNumInputs()); for (auto en : llvm::enumerate(type.getInputs())) { if (en.value().isa() || en.value().isa()) - promotedArgIndices.push_back(en.index()); + argsInfo.push_back({en.index(), en.value()}); } + } - // Convert the original function arguments. Struct arguments are promoted to - // pointer to struct arguments to allow calling external functions with - // various ABIs (e.g. 
compiled from C/C++ on platform X). + // Convert input FuncOp to a new FuncOp in LLVM dialect by using the + // LLVMTypeConverter provided to this legalization pattern. + LLVM::LLVMFuncOp + convertFuncOpToLLVMFuncOp(FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Convert the original function arguments. They are converted using the + // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = lowering.convertFunctionSignature( @@ -533,22 +577,92 @@ // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( - op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, + funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + return newFuncOp; + } +}; + +// FuncOp legalization pattern that converts MemRef arguments to pointers to +// MemRef descriptors (LLVM struct data types) containing all the MemRef type +// information. +struct FuncOpConversion : public FuncOpConversionBase { + using FuncOpConversionBase::FuncOpConversionBase; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + + // Store the positions of memref-typed arguments so that we can emit loads + // from them to follow the calling convention. + SmallVector promotedArgsInfo; + getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + // Insert loads from memref descriptor pointers in function bodies. 
if (!newFuncOp.getBody().empty()) { Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); - for (unsigned idx : promotedArgIndices) { - BlockArgument arg = firstBlock->getArgument(idx); + for (const auto &argInfo : promotedArgsInfo) { + BlockArgument arg = firstBlock->getArgument(argInfo.first); Value loaded = rewriter.create(funcOp.getLoc(), arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); + rewriter.replaceUsesOfWith(arg, loaded); + } + } + + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +// FuncOp legalization pattern that converts MemRef arguments to bare pointers +// to the MemRef element type. This will impact the calling convention and ABI. +struct BarePtrFuncOpConversion : public FuncOpConversionBase { + using FuncOpConversionBase::FuncOpConversionBase; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast(op); + + // Store the positions and type of memref-typed arguments so that we can + // promote them to MemRef descriptor structs at the beginning of the + // function. + SmallVector promotedArgsInfo; + getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); + + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + + // Promote bare pointers from MemRef arguments to a MemRef descriptor struct + // at the beginning of the function so that all the MemRefs in the function + // have a uniform representation. + if (!newFuncOp.getBody().empty()) { + Block *firstBlock = &newFuncOp.getBody().front(); + rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); + auto funcLoc = funcOp.getLoc(); + for (const auto &argInfo : promotedArgsInfo) { + // TODO: Add support for unranked MemRefs. 
+ if (auto memrefType = argInfo.second.dyn_cast()) { + // Replace argument with a placeholder (undef), promote argument to a + // MemRef descriptor and replace placeholder with the last instruction + // of the MemRef descriptor. The placeholder is needed to avoid + // replacing argument uses in the MemRef descriptor instructions. + BlockArgument arg = firstBlock->getArgument(argInfo.first); + Value placeHolder = + rewriter.create(funcLoc, arg.getType()); + rewriter.replaceUsesOfWith(arg, placeHolder); + auto desc = MemRefDescriptor::fromStaticShape( + rewriter, funcLoc, lowering, memrefType, arg); + rewriter.replaceUsesOfWith(placeHolder, desc); + placeHolder.getDefiningOp()->erase(); + } } } @@ -2126,7 +2240,6 @@ // clang-format off patterns.insert< DimOpLowering, - FuncOpConversion, LoadOpLowering, MemRefCastOpLowering, StoreOpLowering, @@ -2139,8 +2252,26 @@ // clang-format on } +void mlir::populateStdToLLVMDefaultFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(*converter.getDialect(), converter); +} + void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns); + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns); +} + +void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(*converter.getDialect(), converter); +} + +void mlir::populateStdToLLVMBarePtrConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatters(converter, patterns); } @@ -2210,6 +2341,12 @@ return std::make_unique(context); } +/// 
Create an instance of BarePtrTypeConverter in the given context. +static std::unique_ptr +makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) { + return std::make_unique(context); +} + namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ModulePass { @@ -2274,6 +2411,9 @@ "Standard to the LLVM dialect", [] { return std::make_unique( - clUseAlloca.getValue(), populateStdToLLVMConversionPatterns, - makeStandardToLLVMTypeConverter); + clUseAlloca.getValue(), + clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns + : populateStdToLLVMConversionPatterns, + clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter + : makeStandardToLLVMTypeConverter); }); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -861,8 +861,7 @@ return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { +void ConversionPatternRewriter::replaceUsesOfWith(Value from, Value to) { for (auto &u : from.getUses()) { if (u.getOwner() == to.getDefiningOp()) continue; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir rename from mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir rename to mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -1,10 +1,4 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s -// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA - -// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: 
!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref, %mixed : memref<10x?xf32>) { - return -} // CHECK-LABEL: func @check_strided_memref_arguments( // CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -14,74 +8,11 @@ return } -// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { -func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { -// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - return %static : memref<32x18xf32> -} - -// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { -// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { -func @zero_d_alloc() -> memref { -// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 -// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> - -// ALLOCA-NOT: malloc -// 
ALLOCA: alloca -// ALLOCA-NOT: malloc - %0 = alloc() : memref - return %0 : memref -} - -// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) { -func @zero_d_dealloc(%arg0: memref) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> -// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () - dealloc %arg0 : memref +// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) +func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref, %mixed : memref<10x?xf32>) { return } -// CHECK-LABEL: func @aligned_1d_alloc( -func @aligned_1d_alloc() -> memref<42xf32> { -// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 -// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 -// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 -// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: 
llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 -// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> - %0 = alloc() {alignment = 8} : memref<42xf32> - return %0 : memref<42xf32> -} - // CHECK-LABEL: func @mixed_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> { func @mixed_alloc(%arg0: index, %arg1: index) -> memref { @@ -162,61 +93,6 @@ return } -// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { -func @static_alloc() -> memref<32x18xf32> { -// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 -// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 -// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 
-// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> - %0 = alloc() : memref<32x18xf32> - return %0 : memref<32x18xf32> -} - -// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) { -func @static_dealloc(%static: memref<10x8xf32>) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> -// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () - dealloc %static : memref<10x8xf32> - return -} - -// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float { -func @zero_d_load(%arg0: memref) -> f32 { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*"> - %0 = load %arg0[] : memref - return %0 : f32 -} - -// CHECK-LABEL: func @static_load( -// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 -func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK-NEXT: 
%[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 -// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*"> - %0 = load %static[%i, %j] : memref<10x42xf32> - return -} - // CHECK-LABEL: func @mixed_load( // CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { @@ -283,34 +159,6 @@ return } -// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) { -func @zero_d_store(%arg0: memref, %arg1: f32) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*"> - store %arg1, %arg0[] : memref - return -} - -// CHECK-LABEL: func @static_store -func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : 
index) : !llvm.i64 -// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 -// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*"> - store %val, %static[%i, %j] : memref<10x42xf32> - return -} - // CHECK-LABEL: func @dynamic_store func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f32) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -440,20 +288,3 @@ %4 = dim %mixed, 4 : memref<42x?x?x13x?xf32> return } - -// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { -func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> -// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 - %0 = dim %static, 0 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 - %1 = dim %static, 1 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 - %2 = dim %static, 2 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 - %3 = dim %static, 3 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 - %4 = dim %static, 4 : memref<42x32x15x13x27xf32> - return -} - diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir new file mode 100644 
--- /dev/null +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -0,0 +1,322 @@ +// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s +// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA +// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR + +// BAREPTR-LABEL: func @check_noalias +// BAREPTR-SAME: %{{.*}}: !llvm<"float*"> {llvm.noalias = true} +func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) { + return +} + +// ----- + +// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +// BAREPTR-LABEL: func @check_static_return +// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { +// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + +// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: 
%[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + return %static : memref<32x18xf32> +} + +// ----- + +// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { +// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { +// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { +func @zero_d_alloc() -> memref { +// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> + +// ALLOCA-NOT: malloc 
+// ALLOCA: alloca +// ALLOCA-NOT: malloc + +// BAREPTR-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> +// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> + %0 = alloc() : memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) { +// BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) { +func @zero_d_dealloc(%arg0: memref) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + dealloc %arg0 : memref + return +} + +// ----- + +// CHECK-LABEL: func 
@aligned_1d_alloc( +// BAREPTR-LABEL: func @aligned_1d_alloc( +func @aligned_1d_alloc() -> memref<42xf32> { +// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 +// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 +// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT:
llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + +// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 +// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> +// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 +// BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 +// BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 +// BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> +// BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 :
index) : !llvm.i64 +// BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %0 = alloc() {alignment = 8} : memref<42xf32> + return %0 : memref<42xf32> +} + +// ----- + +// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +// BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +func @static_alloc() -> memref<32x18xf32> { +// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> + +// BAREPTR-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64 +// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call
@malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> +// BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> + %0 = alloc() : memref<32x18xf32> + return %0 : memref<32x18xf32> +} + +// ----- + +// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) { +// BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) { +func @static_dealloc(%static: memref<10x8xf32>) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + dealloc %static : memref<10x8xf32> + return +} + +// ----- + +// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float { +// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float +func @zero_d_load(%arg0: memref) -> f32 { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*"> + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[addr:.*]] = 
llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + %0 = load %arg0[] : memref + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func @static_load( +// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 +// BAREPTR-LABEL: func @static_load +// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) { +func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul
%[[J]], %[[st1]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + %0 = load %static[%i, %j] : memref<10x42xf32> + return +} + +// ----- + +// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) { +// BAREPTR-LABEL: func @zero_d_store +// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float) +func @zero_d_store(%arg0: memref, %arg1: f32) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*"> + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> +// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm<"float*"> + store %arg1, %arg0[] : memref + return +} + +// ----- + +// CHECK-LABEL: func @static_store +// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 +// BAREPTR-LABEL: func @static_store +// BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 +func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[st0:.*]] =
llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*"> + +// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*"> + store %val, %static[%i, %j] : memref<10x42xf32> + return +} + +// ----- + +// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { +// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) { +func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> +// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 +// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// 
BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 + %0 = dim %static, 0 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 +// BAREPTR-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 + %1 = dim %static, 1 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 +// BAREPTR-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 + %2 = dim %static, 2 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 +// BAREPTR-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 + %3 = dim %static, 3 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 +// BAREPTR-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 + %4 = dim %static, 4 : memref<42x32x15x13x27xf32> + return +}