diff --git a/mlir/include/mlir/Analysis/DataLayoutAnalysis.h b/mlir/include/mlir/Analysis/DataLayoutAnalysis.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataLayoutAnalysis.h
@@ -0,0 +1,48 @@
+//===- DataLayoutAnalysis.h - API for Querying Nested Data Layout -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATALAYOUTANALYSIS_H
+#define MLIR_ANALYSIS_DATALAYOUTANALYSIS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+
+#include <memory>
+
+namespace mlir {
+
+class Operation;
+class DataLayout;
+
+/// Stores data layout objects for each operation that specifies the data
+/// layout above and below the given operation.
+class DataLayoutAnalysis {
+public:
+  /// Constructs the data layouts.
+  explicit DataLayoutAnalysis(Operation *root);
+
+  /// Returns the data layout active at the given operation, that is the
+  /// data layout specified by the closest ancestor that can specify one, or
+  /// the default layout if there is no such ancestor.
+  const DataLayout &getAbove(Operation *operation) const;
+
+  /// Returns the data layout specified by the given operation or its closest
+  /// ancestor that can specify one.
+  const DataLayout &getAtOrAbove(Operation *operation) const;
+
+private:
+  /// Storage for individual data layouts.
+  DenseMap<Operation *, std::unique_ptr<DataLayout>> layouts;
+
+  /// Default data layout in case no operations specify one.
+  std::unique_ptr<DataLayout> defaultLayout;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_DATALAYOUTANALYSIS_H
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -29,6 +29,7 @@
 class BaseMemRefType;
 class ComplexType;
+class DataLayoutAnalysis;
 class LLVMTypeConverter;
 class UnrankedMemRefType;
@@ -62,10 +63,14 @@
   using TypeConverter::convertType;

   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
-  LLVMTypeConverter(MLIRContext *ctx);
+  /// Optionally takes a data layout analysis to use in conversions.
+  LLVMTypeConverter(MLIRContext *ctx,
+                    const DataLayoutAnalysis *analysis = nullptr);

-  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
-  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally
+  /// takes a data layout analysis to use in conversions.
+  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
+                    const DataLayoutAnalysis *analysis = nullptr);

   /// Convert a function type. The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -124,6 +129,11 @@
   /// Returns the data layout to use during and after conversion.
   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }

+  /// Returns the data layout analysis to query during conversion.
+  const DataLayoutAnalysis *getDataLayoutAnalysis() const {
+    return dataLayoutAnalysis;
+  }
+
   /// Gets the LLVM representation of the index type. The returned type is an
   /// integer type with the size configured for this type converter.
   Type getIndexType();
@@ -134,6 +144,13 @@
   /// Gets the pointer bitwidth.
   unsigned getPointerBitwidth(unsigned addressSpace = 0);

+  /// Returns the size of the memref descriptor object in bytes.
+  unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout);
+
+  /// Returns the size of the unranked memref descriptor object in bytes.
+  unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
+                                           const DataLayout &layout);
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -207,11 +224,14 @@
   /// Convert a memref type to a bare pointer to the memref element type.
   Type convertMemRefToBarePtr(BaseMemRefType type);

-  // Convert a 1D vector type into an LLVM vector type.
+  /// Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);

   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
+
+  /// Data layout analysis mapping scopes to layouts active in them.
+  const DataLayoutAnalysis *dataLayoutAnalysis;
 };

 /// Helper class to produce LLVM dialect operations extracting or inserting
@@ -634,11 +654,6 @@
     return op->getResult(0).getType().cast<MemRefType>();
   }

-  LogicalResult match(Operation *op) const override {
-    MemRefType memRefType = getMemRefResultType(op);
-    return success(isConvertibleAndHasIdentityMaps(memRefType));
-  }
-
  // An `alloc` is converted into a definition of a memref descriptor value and
  // a call to `malloc` to allocate the underlying data buffer. The memref
  // descriptor is of the LLVM structure type where:
@@ -655,8 +670,9 @@

  // An `alloca` is converted into a definition of a memref descriptor value and
  // an llvm.alloca to allocate the underlying data buffer.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };

 namespace LLVM {
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -32,7 +32,7 @@
 class LowerToLLVMOptions {
 public:
   explicit LowerToLLVMOptions(MLIRContext *ctx);
-  explicit LowerToLLVMOptions(MLIRContext *ctx, DataLayout dl);
+  explicit LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl);

   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -272,7 +272,8 @@
 }

 inline bool BaseMemRefType::isValidElementType(Type type) {
-  return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() ||
+  return type.isIntOrIndexOrFloat() ||
+         type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
          type.isa<MemRefElementTypeInterface>();
 }

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -313,6 +313,7 @@
     - built-in index type;
     - built-in floating point types;
    - built-in vector types with elements of the above types;
+    - another memref type;
    - any other type implementing `MemRefElementTypeInterface`.
     ##### Codegen of Unranked Memref
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -5,6 +5,7 @@
   BufferViewFlowAnalysis.cpp
   CallGraph.cpp
   DataFlowAnalysis.cpp
+  DataLayoutAnalysis.cpp
   LinearTransform.cpp
   Liveness.cpp
   LoopAnalysis.cpp
@@ -22,6 +23,7 @@
   BufferViewFlowAnalysis.cpp
   CallGraph.cpp
   DataFlowAnalysis.cpp
+  DataLayoutAnalysis.cpp
   Liveness.cpp
   NumberOfExecutions.cpp
   SliceAnalysis.cpp
diff --git a/mlir/lib/Analysis/DataLayoutAnalysis.cpp b/mlir/lib/Analysis/DataLayoutAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/DataLayoutAnalysis.cpp
@@ -0,0 +1,51 @@
+//===- DataLayoutAnalysis.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+using namespace mlir;
+
+DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
+    : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
+  // Construct a DataLayout if possible from the op.
+  auto computeLayout = [this](Operation *op) {
+    if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
+      layouts[op] = std::make_unique<DataLayout>(iface);
+    if (auto module = dyn_cast<ModuleOp>(op))
+      layouts[op] = std::make_unique<DataLayout>(module);
+  };
+
+  // Compute layouts for both ancestors and descendants.
+  root->walk(computeLayout);
+  for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
+       ancestor = ancestor->getParentOp()) {
+    computeLayout(ancestor);
+  }
+}
+
+const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
+  for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr;
+       ancestor = ancestor->getParentOp()) {
+    auto it = layouts.find(ancestor);
+    if (it != layouts.end())
+      return *it->getSecond();
+  }
+
+  // Fallback to the default layout.
+  return *defaultLayout;
+}
+
+const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
+  auto it = layouts.find(operation);
+  if (it != layouts.end())
+    return *it->getSecond();
+  return getAbove(operation);
+}
diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -12,6 +12,7 @@
   Core

   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRDataLayoutInterfaces
   MLIRLLVMIR
   MLIRMath
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//

 #include "../PassDetail.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -101,14 +102,16 @@
 }

 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
-LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx)) {}
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
+                                     const DataLayoutAnalysis *analysis)
+    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}

 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
-                                     const LowerToLLVMOptions &options)
-    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
-      options(options) {
+                                     const LowerToLLVMOptions &options,
+                                     const DataLayoutAnalysis *analysis)
+    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
+      dataLayoutAnalysis(analysis) {
   assert(llvmDialect && "LLVM IR dialect is not registered");

   // Register conversions for the builtin types.
@@ -342,6 +345,14 @@
   return results;
 }

+unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
+                                                    const DataLayout &layout) {
+  // Compute the descriptor size given that of its components indicated above.
+  unsigned space = type.getMemorySpaceAsInt();
+  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
+         (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
+}
+
 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
@@ -369,6 +380,15 @@
           LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
 }

+unsigned
+LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
+                                                   const DataLayout &layout) {
+  // Compute the descriptor size given that of its components indicated above.
+  unsigned space = type.getMemorySpaceAsInt();
+  return layout.getTypeSize(getIndexType()) +
+         llvm::divideCeil(getPointerBitwidth(space), 8);
+}
+
 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
   if (!convertType(type.getElementType()))
     return {};
@@ -1900,26 +1920,30 @@
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                 converter) {}

-  /// Returns the memref's element size in bytes.
+  /// Returns the memref's element size in bytes using the data layout active
+  /// at `op`.
   // TODO: there are other places where this is used. Expose publicly?
-  static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-    auto elementType = memRefType.getElementType();
-
-    unsigned sizeInBits;
-    if (elementType.isIntOrFloat()) {
-      sizeInBits = elementType.getIntOrFloatBitWidth();
-    } else {
-      auto vectorType = elementType.cast<VectorType>();
-      sizeInBits =
-          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
+    const DataLayout *layout = &defaultLayout;
+    if (const DataLayoutAnalysis *analysis =
+            getTypeConverter()->getDataLayoutAnalysis()) {
+      layout = &analysis->getAbove(op);
     }
-    return llvm::divideCeil(sizeInBits, 8);
+    Type elementType = memRefType.getElementType();
+    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
+      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
+                                                         *layout);
+    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
+      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
+          memRefElementType, *layout);
+    return layout->getTypeSize(elementType);
   }

   /// Returns true if the memref size in bytes is known to be a multiple of
-  /// factor.
-  static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
-    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
+  /// factor assuming the data layout active at `op`.
+  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
+                              Operation *op) const {
+    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
       if (type.isDynamic(type.getDimSize(i)))
         continue;
@@ -1938,7 +1962,7 @@
     // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the allocation size has to be a
    // power of two, we will bump to the next power of two if it already isn't.
-    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
+    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
     return std::max(kMinAlignedAllocAlignment,
                     llvm::PowerOf2Ceil(eltSizeBytes));
   }
@@ -1954,7 +1978,7 @@

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
-    if (!isMemRefSizeMultipleOf(memRefType, alignment))
+    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

     Type elementPtrType = this->getElementPtrType(memRefType);
@@ -1971,6 +1995,9 @@

   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+
+  /// Default layout to use in absence of the corresponding analysis.
+  DataLayout defaultLayout;
 };

 // Out of line definition, required till C++17.
@@ -4068,8 +4095,10 @@
   }

   ModuleOp m = getOperation();
+  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();

-  LowerToLLVMOptions options(&getContext(), DataLayout(m));
+  LowerToLLVMOptions options(&getContext(),
+                             dataLayoutAnalysis.getAtOrAbove(m));
   options.useBarePtrCallConv = useBarePtrCallConv;
   options.emitCWrappers = emitCWrappers;
   if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
@@ -4078,7 +4107,9 @@
       (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                        : LowerToLLVMOptions::AllocLowering::Malloc);
   options.dataLayout = llvm::DataLayout(this->dataLayout);
-  LLVMTypeConverter typeConverter(&getContext(), options);
+
+  LLVMTypeConverter typeConverter(&getContext(), options,
+                                  &dataLayoutAnalysis);

   RewritePatternSet patterns(&getContext());
   populateStdToLLVMConversionPatterns(typeConverter, patterns);
@@ -4102,10 +4133,12 @@
   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
 }

-void AllocLikeOpLLVMLowering::rewrite(
+LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   MemRefType memRefType = getMemRefResultType(op);
+  if (!isConvertibleAndHasIdentityMaps(memRefType))
+    return rewriter.notifyMatchFailure(op, "incompatible memref type");
   auto loc = op->getLoc();

   // Get actual sizes of the memref as values: static sizes are constant
@@ -4129,6 +4162,7 @@

   // Return the final value of the descriptor.
   rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
 }

 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
@@ -4159,6 +4193,6 @@
     : LowerToLLVMOptions(ctx, DataLayout()) {}

 mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx,
-                                             mlir::DataLayout dl) {
+                                             const DataLayout &dl) {
   indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx));
 }
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
-// RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
+// RUN: mlir-opt -split-input-file -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC

 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm.ptr<f32>
@@ -529,3 +529,98 @@

 // CHECK: ^bb3:
 // CHECK: llvm.return
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @memref_of_memref
+func @memref_of_memref() {
+  // Sizeof computation is as usual.
+  // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null
+  // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+  // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+  // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 8) = 64.
+  // ALIGNED-ALLOC: llvm.mlir.constant(64 : index)
+
+  // Check that the types are converted as expected.
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  // ALIGNED-ALLOC: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to
+  // ALIGNED-ALLOC-SAME: !llvm.
+  // ALIGNED-ALLOC-SAME: [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>\)>>]]
+  // ALIGNED-ALLOC: llvm.mlir.undef
+  // ALIGNED-ALLOC-SAME: !llvm.struct<([[INNER]], [[INNER]], i64, array<1 x i64>, array<1 x i64>)>
+  %0 = memref.alloc() : memref<1xmemref<1xf32>>
+  return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+  // ALIGNED-ALLOC-LABEL: @memref_of_memref_32
+  func @memref_of_memref_32() {
+    // Sizeof computation is as usual.
+    // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null
+    // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+    // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+    // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+    // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 4) = 32.
+    // ALIGNED-ALLOC: llvm.mlir.constant(32 : index)
+
+    // Check that the types are converted as expected.
+    // ALIGNED-ALLOC: llvm.call @aligned_alloc
+    // ALIGNED-ALLOC: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to
+    // ALIGNED-ALLOC-SAME: !llvm.
+    // ALIGNED-ALLOC-SAME: [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i32, array<1 x i32>, array<1 x i32>\)>>]]
+    // ALIGNED-ALLOC: llvm.mlir.undef
+    // ALIGNED-ALLOC-SAME: !llvm.struct<([[INNER]], [[INNER]], i32, array<1 x i32>, array<1 x i32>)>
+    %0 = memref.alloc() : memref<1xmemref<1xf32>>
+    return
+  }
+}
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @memref_of_memref_of_memref
+func @memref_of_memref_of_memref() {
+  // Sizeof computation is as usual, also check the type.
+  // ALIGNED-ALLOC: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<
+  // ALIGNED-ALLOC-SAME: struct<(
+  // ALIGNED-ALLOC-SAME:   [[INNER:ptr<struct<\(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>\)>>]],
+  // ALIGNED-ALLOC-SAME:   [[INNER]],
+  // ALIGNED-ALLOC-SAME:   i64, array<1 x i64>, array<1 x i64>
+  // ALIGNED-ALLOC-SAME: )>
+  // ALIGNED-ALLOC-SAME: >
+  // ALIGNED-ALLOC: %[[PTR:.*]] = llvm.getelementptr
+  // ALIGNED-ALLOC: %[[SIZEOF:.*]] = llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(2 * sizeof(pointer) +
+  // (1 + 2 * rank) * sizeof(index)) = ceilPowerOf2(2 * 8 + 3 * 8) = 64.
+  // ALIGNED-ALLOC: llvm.mlir.constant(64 : index)
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  %0 = memref.alloc() : memref<1 x memref<2 x memref<3 x f32>>>
+  return
+}
+
+// -----
+
+// ALIGNED-ALLOC-LABEL: @ranked_unranked
+func @ranked_unranked() {
+  // ALIGNED-ALLOC: llvm.mlir.null
+  // ALIGNED-ALLOC-SAME: !llvm.[[INNER:ptr<struct<\(i64, ptr<i8>\)>>]]
+  // ALIGNED-ALLOC: llvm.getelementptr
+  // ALIGNED-ALLOC: llvm.ptrtoint
+
+  // Static alignment should be computed as ceilPowerOf2(sizeof(index) +
+  // sizeof(pointer)) = 16.
+  // ALIGNED-ALLOC: llvm.mlir.constant(16 : index)
+  // ALIGNED-ALLOC: llvm.call @aligned_alloc
+  // ALIGNED-ALLOC: llvm.bitcast
+  // ALIGNED-ALLOC-SAME: !llvm.ptr<i8> to !llvm.[[INNER]]
+  %0 = memref.alloc() : memref<1 x memref<* x f32>>
+  memref.cast %0 : memref<1 x memref<* x f32>> to memref<* x memref<* x f32>>
+  return
+}
+
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -447,3 +447,4 @@
 // BAREPTR-SAME: memref<
 // BAREPTR-NOT: !llvm.ptr
 func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element>
+
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -181,6 +181,18 @@
 // CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
 func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)

+// CHECK: func private @memref_of_memref(memref<1xmemref<1xf64>>)
+func private @memref_of_memref(memref<1xmemref<1xf64>>)
+
+// CHECK: func private @memref_of_unranked_memref(memref<1xmemref<*xf32>>)
+func private @memref_of_unranked_memref(memref<1xmemref<*xf32>>)
+
+// CHECK: func private @unranked_memref_of_memref(memref<*xmemref<1xf32>>)
+func private @unranked_memref_of_memref(memref<*xmemref<1xf32>>)
+
+// CHECK: func private @unranked_memref_of_unranked_memref(memref<*xmemref<*xi32>>)
+func private @unranked_memref_of_unranked_memref(memref<*xmemref<*xi32>>)
+
 // CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//

 #include "TestDialect.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Pass/Pass.h"
@@ -23,28 +24,14 @@
   void runOnFunction() override {
     FuncOp func = getFunction();
     Builder builder(func.getContext());
-    DenseMap<Operation *, DataLayout> layouts;
+    const DataLayoutAnalysis &layouts = getAnalysis<DataLayoutAnalysis>();

     func.walk([&](test::DataLayoutQueryOp op) {
       // Skip the ops with already processed in a deeper call.
       if (op->getAttr("size"))
         return;

-      auto scope = op->getParentOfType<test::OpWithDataLayoutOp>();
-      if (!layouts.count(scope)) {
-        layouts.try_emplace(
-            scope, scope ? cast<DataLayoutOpInterface>(scope.getOperation())
-                         : nullptr);
-      }
-      auto module = op->getParentOfType<ModuleOp>();
-      if (!layouts.count(module))
-        layouts.try_emplace(module, module);
-
-      Operation *closest = (scope && module && module->isProperAncestor(scope))
-                               ? scope.getOperation()
-                               : module.getOperation();
-
-      const DataLayout &layout = layouts.find(closest)->getSecond();
+      const DataLayout &layout = layouts.getAbove(op);

       unsigned size = layout.getTypeSize(op.getType());
       unsigned bitsize = layout.getTypeSizeInBits(op.getType());
       unsigned alignment = layout.getTypeABIAlignment(op.getType());
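
---

For context, here is a minimal sketch (not part of the patch itself) of how a downstream pass might wire the new `DataLayoutAnalysis` into the converter. It uses only APIs introduced or extended above: `getAnalysis<DataLayoutAnalysis>()`, `getAtOrAbove`, and the three-argument `LLVMTypeConverter` constructor. The pass name `MyToLLVMPass` is invented for illustration; this mirrors what the updated `-convert-std-to-llvm` pass does in `StandardToLLVM.cpp`.

```c++
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

// Hypothetical pass; only the analysis wiring is the point here.
struct MyToLLVMPass
    : public PassWrapper<MyToLLVMPass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    ModuleOp module = getOperation();

    // Layouts are computed once for the module and all nested and enclosing
    // layout scopes, then handed to the converter by pointer.
    const auto &dlAnalysis = getAnalysis<DataLayoutAnalysis>();

    // Seed the options from the layout at (or above) the module so that,
    // e.g., an index bitwidth of 32 from a dlti.dl_spec attribute is honored.
    LowerToLLVMOptions options(&getContext(), dlAnalysis.getAtOrAbove(module));

    // The converter consults the analysis when sizing memref descriptors
    // during alloc lowering: for a rank-1 memref element on a 64-bit target,
    // 2 * sizeof(ptr) + (1 + 2 * rank) * sizeof(index) = 2*8 + 3*8 = 40 bytes,
    // which aligned_alloc rounds up to the next power of two, 64.
    LLVMTypeConverter typeConverter(&getContext(), options, &dlAnalysis);
    (void)typeConverter;
  }
};
```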