diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -55,6 +55,10 @@ /// load/store ops into `patterns`. std::unique_ptr<Pass> createFoldSubViewOpsPass(); +/// Creates an interprocedural pass to normalize memrefs to have a trivial +/// (identity) layout map. +std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass(); + /// Creates an operation pass to resolve `memref.dim` operations with values /// that are defined by operations that implement the /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -23,6 +23,122 @@ ]; } +def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> { + let summary = "Normalize memrefs"; + let description = [{ + This pass transforms memref types with a non-trivial + [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into + memref types with an identity layout map, e.g. (i, j) -> (i, j). This + pass is inter-procedural, in the sense that it can modify function + interfaces and call sites that pass memref types. In order to modify + memref types while preserving the original behavior, users of those + memref types are also modified to incorporate the resulting layout map. + For instance, an [AffineLoadOp] + (https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop) + will be updated to compose the layout map with the affine expression + contained in the op. Operations marked with the [MemRefsNormalizable] + (https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are + expected to be normalizable. 
Supported operations include affine + operations, memref.alloc, memref.dealloc, and std.return. + + Given an appropriate layout map specified in the code, this transformation + can express tiled or linearized access to multi-dimensional data + structures, but will not modify memref types without an explicit layout + map. + + Currently this pass is limited to only modify + functions where all memref types can be normalized. If a function + contains any operations that are not MemRefsNormalizable, then the function + and any functions that it calls or that call it will not be modified. + + Input + + ```mlir + #tile = affine_map<(i) -> (i floordiv 4, i mod 4)> + func @matmul(%A: memref<16xf64, #tile>, + %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) { + affine.for %arg3 = 0 to 16 { + %a = affine.load %A[%arg3] : memref<16xf64, #tile> + %p = arith.mulf %a, %a : f64 + affine.store %p, %A[%arg3] : memref<16xf64, #tile> + } + %c = memref.alloc() : memref<16xf64, #tile> + %d = affine.load %c[0] : memref<16xf64, #tile> + return %A: memref<16xf64, #tile> + } + ``` + + Output + + ```mlir + func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>) + -> memref<4x4xf64> { + affine.for %arg3 = 0 to 16 { + %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> + %4 = arith.mulf %3, %3 : f64 + affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> + } + %0 = memref.alloc() : memref<4x4xf64> + %1 = affine.apply #map1() + %2 = affine.load %0[0, 0] : memref<4x4xf64> + return %arg0 : memref<4x4xf64> + } + ``` + + Input + + ```mlir + #linear8 = affine_map<(i, j) -> (i * 8 + j)> + func @linearize(%arg0: memref<8x8xi32, #linear8>, + %arg1: memref<8x8xi32, #linear8>, + %arg2: memref<8x8xi32, #linear8>) { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + affine.for %arg3 = %c0 to %c8 { + affine.for %arg4 = %c0 to %c8 { + affine.for %arg5 = %c0 to %c8 { + %0 = affine.load %arg0[%arg3, %arg5] : 
memref<8x8xi32, #linear8> + %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8> + %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 + affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> + } + } + } + return + } + ``` + + Output + + ```mlir + func @linearize(%arg0: memref<64xi32>, + %arg1: memref<64xi32>, + %arg2: memref<64xi32>) { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + affine.for %arg3 = %c0 to %c8 { + affine.for %arg4 = %c0 to %c8 { + affine.for %arg5 = %c0 to %c8 { + %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32> + %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32> + %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32> + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 + affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32> + } + } + } + return + } + ``` + }]; + let constructor = "mlir::memref::createNormalizeMemRefsPass()"; + let dependentDialects = ["AffineDialect"]; +} + def ResolveRankedShapeTypeResultDims : Pass<"resolve-ranked-shaped-type-result-dims"> { let summary = "Resolve memref.dim of result values of ranked shape type"; diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -110,10 +110,6 @@ /// pass may *only* be scheduled on an operation that defines a SymbolTable. std::unique_ptr createSymbolDCEPass(); -/// Creates an interprocedural pass to normalize memrefs to have a trivial -/// (identity) layout map. 
-std::unique_ptr> createNormalizeMemRefsPass(); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -327,122 +327,6 @@ let constructor = "mlir::createLoopInvariantCodeMotionPass()"; } -def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> { - let summary = "Normalize memrefs"; - let description = [{ - This pass transforms memref types with a non-trivial - [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into - memref types with an identity layout map, e.g. (i, j) -> (i, j). This - pass is inter-procedural, in the sense that it can modify function - interfaces and call sites that pass memref types. In order to modify - memref types while preserving the original behavior, users of those - memref types are also modified to incorporate the resulting layout map. - For instance, an [AffineLoadOp] - (https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop) - will be updated to compose the layout map with with the affine expression - contained in the op. Operations marked with the [MemRefsNormalizable] - (https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are - expected to be normalizable. Supported operations include affine - operations, memref.alloc, memref.dealloc, and std.return. - - Given an appropriate layout map specified in the code, this transformation - can express tiled or linearized access to multi-dimensional data - structures, but will not modify memref types without an explicit layout - map. - - Currently this pass is limited to only modify - functions where all memref types can be normalized. 
If a function - contains any operations that are not MemRefNormalizable, then the function - and any functions that call or call it will not be modified. - - Input - - ```mlir - #tile = affine_map<(i) -> (i floordiv 4, i mod 4)> - func @matmul(%A: memref<16xf64, #tile>, - %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) { - affine.for %arg3 = 0 to 16 { - %a = affine.load %A[%arg3] : memref<16xf64, #tile> - %p = arith.mulf %a, %a : f64 - affine.store %p, %A[%arg3] : memref<16xf64, #tile> - } - %c = memref.alloc() : memref<16xf64, #tile> - %d = affine.load %c[0] : memref<16xf64, #tile> - return %A: memref<16xf64, #tile> - } - ``` - - Output - - ```mlir - func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>) - -> memref<4x4xf64> { - affine.for %arg3 = 0 to 16 { - %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> - %4 = arith.mulf %3, %3 : f64 - affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> - } - %0 = memref.alloc() : memref<4x4xf64> - %1 = affine.apply #map1() - %2 = affine.load %0[0, 0] : memref<4x4xf64> - return %arg0 : memref<4x4xf64> - } - ``` - - Input - - ``` - #linear8 = affine_map<(i, j) -> (i * 8 + j)> - func @linearize(%arg0: memref<8x8xi32, #linear8>, - %arg1: memref<8x8xi32, #linear8>, - %arg2: memref<8x8xi32, #linear8>) { - %c8 = arith.constant 8 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - affine.for %arg3 = %c0 to %c8 { - affine.for %arg4 = %c0 to %c8 { - affine.for %arg5 = %c0 to %c8 { - %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8> - %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8> - %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> - %3 = arith.muli %0, %1 : i32 - %4 = arith.addi %2, %3 : i32 - affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> - } - } - } - return - } - ``` - - Output - - ```mlir - func @linearize(%arg0: memref<64xi32>, - %arg1: memref<64xi32>, - %arg2: 
memref<64xi32>) { - %c8 = arith.constant 8 : index - %c0 = arith.constant 0 : index - affine.for %arg3 = %c0 to %c8 { - affine.for %arg4 = %c0 to %c8 { - affine.for %arg5 = %c0 to %c8 { - %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32> - %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32> - %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32> - %3 = arith.muli %0, %1 : i32 - %4 = arith.addi %2, %3 : i32 - affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32> - } - } - } - return - } - ``` - }]; - let constructor = "mlir::createNormalizeMemRefsPass()"; - let dependentDialects = ["AffineDialect"]; -} - def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> { let summary = "Collapse parallel loops to use less induction variables"; let constructor = "mlir::createParallelLoopCollapsingPass()"; diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms FoldSubViewOps.cpp + NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp rename from mlir/lib/Transforms/NormalizeMemRefs.cpp rename to mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -14,7 +14,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" @@ -43,7 +43,8 @@ } // namespace -std::unique_ptr> mlir::createNormalizeMemRefsPass() { +std::unique_ptr> 
+mlir::memref::createNormalizeMemRefsPass() { return std::make_unique<NormalizeMemRefs>(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h @@ -0,0 +1,43 @@ +//===- PassDetail.h - MemRef Pass class details -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class AffineDialect; + +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry &registry); + +namespace arith { +class ArithmeticDialect; +} // namespace arith + +namespace memref { +class MemRefDialect; +} // namespace memref + +namespace tensor { +class TensorDialect; +} // namespace tensor + +namespace vector { +class VectorDialect; +} // namespace vector + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif // DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -107,9 +108,6 @@ 
//===----------------------------------------------------------------------===// namespace { -#define GEN_PASS_CLASSES -#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" - struct ResolveRankedShapeTypeResultDimsPass final : public ResolveRankedShapeTypeResultDimsBase< ResolveRankedShapeTypeResultDimsPass> { diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -8,7 +8,6 @@ LoopCoalescing.cpp LoopFusion.cpp LoopInvariantCodeMotion.cpp - NormalizeMemRefs.cpp OpStats.cpp ParallelLoopCollapsing.cpp PipelineDataTransfer.cpp