diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -693,6 +693,71 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractStridedMetadataOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata",
+    [SameVariadicResultSize]> {
+  let summary = "Extracts a buffer base with offset, sizes and strides";
+  let description = [{
+    Extracts a base buffer, offset, sizes and strides. This op allows additional
+    layers of transformations and foldings to be added as lowering progresses
+    from higher-level dialects to lower-level dialects such as the LLVM dialect.
+
+    The op requires a strided memref source operand. If the source operand is not
+    a strided memref, then verification fails.
+
+    This operation can be viewed as the natural complement to the
+    `memref.reinterpret_cast` op.
+
+    Intended Use Cases:
+
+    The main use case is to lift the materialization of the internal logic of
+    ops that manipulate metadata outside of the LLVM dialect. This makes lowering
+    more progressive and brings the following benefits:
+      - foldings and canonicalizations can happen at a higher level in MLIR:
+        before this op existed, lowering to LLVM would create large amounts of
+        LLVMIR. Even when LLVM does a good job at folding the low-level IR from
+        a performance perspective, it is unnecessarily opaque and inefficient to
+        send unkempt IR to LLVM. This part could be solved by replicating LLVM
+        foldings and canonicalizations in the LLVM dialect but this would still
+        not solve the next problem.
+      - not all users of MLIR want to lower to LLVM and the information to e.g.
+        lower to library calls (e.g. libxsmm) or to SPIR-V was just not
+        available.
+
+    Example:
+
+    ```
+      %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref
+        : memref<10x?xf32> -> memref<f32>, index, index, index, index, index
+
+      // After folding, the type of %m2 can be memref<10x?xf32> and further
+      // folded to %memref.
+      %m2 = memref.reinterpret_cast %base to
+          offset: [%offset],
+          sizes: [%sizes#0, %sizes#1],
+          strides: [%strides#0, %strides#1]
+        : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyStridedMemRef:$source
+  );
+  let results = (outs
+    AnyStridedMemRefOfRank<0>:$base_buffer,
+    Index:$offset,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides
+  );
+
+  let assemblyFormat = [{
+    $source `:` type($source) `->` type(results) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -336,3 +336,20 @@
   } { index_attr = 8 : index }
   return
 }
+
+// -----
+
+func.func @extract_strided_metadata(%memref : memref<10x?xf32>)
+  -> memref<?x?xf32, offset: ?, strides: [?, ?]> {
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref
+    : memref<10x?xf32> -> memref<f32>, index, index, index, index, index
+
+  %m2 = memref.reinterpret_cast %base to
+      offset: [%offset],
+      sizes: [%sizes#0, %sizes#1],
+      strides: [%strides#0, %strides#1]
+    : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  return %m2: memref<?x?xf32, offset: ?, strides: [?, ?]>
+}