diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -16,4 +16,18 @@
 #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 
+namespace mlir {
+
+class MemRefType;
+
+namespace memref {
+
+/// Returns true if the memref type has a static shape and its elements occupy
+/// a single contiguous chunk of memory in row-major order. Leading dimensions
+/// of size 1 may have any stride, and the layout offset is not considered.
+bool isStaticShapeAndContiguousRowMajor(MemRefType type);
+
+} // namespace memref
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
--- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
@@ -16,6 +16,7 @@
   MLIRDataLayoutInterfaces
   MLIRLLVMCommonConversion
   MLIRMemRefDialect
+  MLIRMemRefUtils
   MLIRLLVMDialect
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
@@ -1055,34 +1056,6 @@
     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
 
-    auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
-      if (!type.hasStaticShape())
-        return false;
-
-      SmallVector<int64_t> strides;
-      int64_t offset;
-      if (failed(getStridesAndOffset(type, strides, offset)))
-        return false;
-
-      // MemRef is contiguous if outer dimensions are size-1 and inner
-      // dimensions have unit strides.
-      int64_t runningStride = 1;
-      int64_t curDim = strides.size() - 1;
-      // Finds all inner dimensions with unit strides.
-      while (curDim >= 0 && strides[curDim] == runningStride) {
-        runningStride *= type.getDimSize(curDim);
-        --curDim;
-      }
-
-      // Check if other dimensions are size-1.
-      while (curDim >= 0 && type.getDimSize(curDim) == 1) {
-        --curDim;
-      }
-
-      // All dims are unit-strided or size-1.
-      return curDim < 0;
-    };
-
     auto isContiguousMemrefType = [&](BaseMemRefType type) {
       auto memrefType = dyn_cast<mlir::MemRefType>(type);
       // We can use memcpy for memrefs if they have an identity layout or are
@@ -1091,7 +1064,7 @@
       return memrefType &&
              (memrefType.getLayout().isIdentity() ||
               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
-               isStaticShapeAndContiguousRowMajor(memrefType)));
+               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
     };
 
     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_dialect_library(MLIRMemRefUtils
+  MemRefUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -0,0 +1,53 @@
+//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the MemRef dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir {
+namespace memref {
+
+bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
+  if (!type.hasStaticShape())
+    return false;
+
+  SmallVector<int64_t> strides;
+  // The offset is not checked: a contiguous chunk may start anywhere.
+  int64_t offset;
+  if (failed(getStridesAndOffset(type, strides, offset)))
+    return false;
+
+  // The memref is contiguous if its innermost dimensions are densely packed in
+  // row-major order (each stride equals the product of the sizes of all
+  // dimensions inside it) and every remaining leading dimension has size 1.
+  int64_t runningStride = 1;
+  // Convert before subtracting so that a rank-0 memref (no strides) yields -1
+  // without relying on unsigned wrap-around.
+  int64_t curDim = static_cast<int64_t>(strides.size()) - 1;
+  // Consume the densely packed inner dimensions, innermost first.
+  while (curDim >= 0 && strides[curDim] == runningStride) {
+    runningStride *= type.getDimSize(curDim);
+    --curDim;
+  }
+
+  // The remaining leading dimensions must have size 1; their strides do not
+  // matter because only index 0 is ever used along them.
+  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
+    --curDim;
+  }
+
+  // Contiguous iff every dimension was consumed by one of the loops above.
+  return curDim < 0;
+}
+
+} // namespace memref
+} // namespace mlir