diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -693,6 +693,71 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractStridedMetadataOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata",
+    [SameVariadicResultSize]> {
+  let summary = "Extracts a buffer base with offset and strides";
+  let description = [{
+    Extracts a base buffer, offset and strides. This op allows additional
+    layers of transformations and foldings to be added as lowering progresses
+    from higher-level dialect to lower-level dialects such as the LLVM dialect.
+
+    The op requires a strided memref source operand. If the source operand is
+    not a strided memref, then verification fails.
+
+    This operation is also useful for completeness to the existing memref.dim
+    op. While accessing strides, offsets and the base pointer independently is
+    not available, this is useful for composing with its natural complement op:
+    `memref.reinterpret_cast`.
+
+    Intended Use Cases:
+
+    The main use case is to expose the logic for manipulating memref metadata
+    at a higher level than the LLVM dialect.
+    This makes lowering more progressive and brings the following benefits:
+    - not all users of MLIR want to lower to LLVM and the information to e.g.
+      lower to library calls---like libxsmm---or to SPIR-V was not available.
+    - foldings and canonicalizations can happen at a higher level in MLIR:
+      before this op existed, lowering to LLVM would create large amounts of
+      LLVMIR. Even when LLVM does a good job at folding the low-level IR from
+      a performance perspective, it is unnecessarily opaque and inefficient to
+      send unkempt IR to LLVM.
+
+    Example:
+
+    ```mlir
+      %base, %offset, %sizes:2, %strides:2 =
+        memref.extract_strided_metadata %memref : memref<10x?xf32>
+        -> memref<f32>, index, index, index, index, index
+
+      // After folding, the type of %m2 can be memref<10x?xf32> and further
+      // folded to %memref.
+      %m2 = memref.reinterpret_cast %base to
+          offset: [%offset],
+          sizes: [%sizes#0, %sizes#1],
+          strides: [%strides#0, %strides#1]
+        : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyStridedMemRef:$source
+  );
+  let results = (outs
+    AnyStridedMemRefOfRank<0>:$base_buffer,
+    Index:$offset,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides
+  );
+
+  let assemblyFormat = [{
+    $source `:` type($source) `->` type(results) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -336,3 +336,20 @@
  } { index_attr = 8 : index }
   return
 }
+
+// -----
+
+func.func @extract_strided_metadata(%memref : memref<10x?xf32>)
+  -> memref<?x?xf32, offset: ?, strides: [?, ?]> {
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref
+    : memref<10x?xf32> -> memref<f32>, index, index, index, index, index
+
+  %m2 = memref.reinterpret_cast %base to
+      offset: [%offset],
+      sizes: [%sizes#0, %sizes#1],
+      strides: [%strides#0, %strides#1]
+    : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  return %m2: memref<?x?xf32, offset: ?, strides: [?, ?]>
+}