diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -251,13 +251,13 @@
 *   `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
 
-This trait is used to flag operations that can accommodate `MemRefs` with
-non-identity memory-layout specifications. This trait indicates that the
-normalization of memory layout can be performed for such operations.
-`MemRefs` normalization consists of replacing an original memory reference
-with layout specifications to an equivalent memory reference where
-the specified memory layout is applied by rewritting accesses and types
-associated with that memory reference.
+This trait is used to flag operations that can access `MemRefs` through an
+arbitrary affine expression. In cases where an accessed MemRef has a
+non-identity memory-layout specification, such operations can be
+'normalized' so that the layout of memory is incorporated into the
+index expression of the operation, resulting in a new MemRef type with
+an identity memory-layout specification. See
+[the -normalize-memrefs pass](https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs).
 
 ### Single Block with Implicit Terminator
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1212,13 +1212,12 @@
   }
 };
 
-/// This trait is used to flag operations that can accommodate MemRefs with
-/// non-identity memory-layout specifications. This trait indicates that the
-/// normalization of memory layout can be performed for such operations.
-/// MemRefs normalization consists of replacing an original memory reference
-/// with layout specifications to an equivalent memory reference where the
-/// specified memory layout is applied by rewritting accesses and types
-/// associated with that memory reference.
+/// This trait is used to flag operations that can access MemRefs through an
+/// arbitrary affine expression. In cases where an accessed MemRef has a
+/// non-identity memory-layout specification, such operations can be
+/// 'normalized' so that the layout of memory is incorporated into the
+/// index expression of the operation, resulting in a new MemRef type with
+/// an identity memory-layout specification.
 // TODO: Right now, the operands of an operation are either all normalizable,
 // or not. In the future, we may want to allow some of the operands to be
 // normalizable.
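Not part of the patch itself, but as a rough illustration of the C++ API documented above: an operation opts into this behavior by listing `OpTrait::MemRefsNormalizable` among its traits. The dialect and op names below are hypothetical; real ops would normally declare the trait in their ODS definition instead.

```c++
#include "mlir/IR/OpDefinition.h"

// Hypothetical op, for illustration only: it declares the MemRefsNormalizable
// trait alongside another standard trait when deriving from mlir::Op.
class MyLoadOp
    : public mlir::Op<MyLoadOp, mlir::OpTrait::OneResult,
                      mlir::OpTrait::MemRefsNormalizable> {
public:
  using Op::Op;
  static llvm::StringRef getOperationName() { return "mydialect.load"; }
};
```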
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -313,6 +313,114 @@
 def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
   let summary = "Normalize memrefs";
+  let description = [{
+    This pass transforms memref types with a non-trivial
+    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
+    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
+    pass is inter-procedural, in the sense that it can modify function
+    interfaces and call sites that pass memref types. In order to modify
+    memref types while preserving the original behavior, users of those
+    memref types are also modified to incorporate the resulting layout map.
+    For instance, an
+    [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
+    will be updated to compose the layout map with the affine expression
+    contained in the op.
+
+    Supported operations must be marked with the
+    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable)
+    trait. Currently only affine operations, std.alloc, std.dealloc, and
+    std.return are normalizable.
+
+    Given an appropriate layout map specified in the code, this transformation
+    can express tiled or linearized access to multi-dimensional data
+    structures, but will not modify memref types without an explicit layout
+    map.
+
+    Currently this pass is somewhat conservative: it will only modify
+    functions where all memref types can be normalized. If a function
+    contains any operations that are not MemRefsNormalizable, then that
+    function and any functions that call it or are called by it will not be
+    modified.
+
+    Input
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @matmul(%A: memref<16xf64, #tile>,
+                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
+      affine.for %arg3 = 0 to 16 {
+        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+        %p = mulf %a, %a : f64
+        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+      }
+      %c = alloc() : memref<16xf64, #tile>
+      %d = affine.load %c[0] : memref<16xf64, #tile>
+      return %A: memref<16xf64, #tile>
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+        -> memref<4x4xf64> {
+      affine.for %arg3 = 0 to 16 {
+        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+        %4 = mulf %3, %3 : f64
+        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+      }
+      %0 = alloc() : memref<4x4xf64>
+      %1 = affine.apply #map1()
+      %2 = affine.load %0[0, 0] : memref<4x4xf64>
+      return %arg0 : memref<4x4xf64>
+    }
+    ```
+
+    Input
+
+    ```mlir
+    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
+    func @linearize(%arg0: memref<8x8xi32, #linear8>,
+                    %arg1: memref<8x8xi32, #linear8>,
+                    %arg2: memref<8x8xi32, #linear8>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      %c1 = constant 1 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
+            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
+            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+          }
+        }
+      }
+      return
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @linearize(%arg0: memref<64xi32>,
+                    %arg1: memref<64xi32>,
+                    %arg2: memref<64xi32>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
+            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
+            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+          }
+        }
+      }
+      return
+    }
+    ```
+  }];
   let constructor = "mlir::createNormalizeMemRefsPass()";
 }
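The pass defined above is exposed through `mlir::createNormalizeMemRefsPass()` and registered under the `-normalize-memrefs` flag, so it can be exercised with `mlir-opt -normalize-memrefs`. As a rough sketch, not part of the patch, it could also be scheduled programmatically along these lines; the helper function name is made up.

```c++
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

// Hypothetical helper: run memref normalization over an entire module.
static mlir::LogicalResult normalizeAllMemRefs(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Module-level pass: it may rewrite function signatures and call sites.
  pm.addPass(mlir::createNormalizeMemRefsPass());
  return pm.run(module);
}
```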
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -29,34 +29,6 @@
 /// such functions as normalizable. Also, if a normalizable function is known
 /// to call a non-normalizable function, we treat that function as
 /// non-normalizable as well. We assume external functions to be normalizable.
-///
-/// Input :-
-/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-///   (memref<16xf64, #tile>) {
-///   affine.for %arg3 = 0 to 16 {
-///     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-///     %p = mulf %a, %a : f64
-///     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-///   }
-///   %c = alloc() : memref<16xf64, #tile>
-///   %d = affine.load %c[0] : memref<16xf64, #tile>
-///   return %A: memref<16xf64, #tile>
-/// }
-///
-/// Output :-
-/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-///   -> memref<4x4xf64> {
-///   affine.for %arg3 = 0 to 16 {
-///     %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] :
-///     memref<4x4xf64> %3 = mulf %2, %2 : f64 affine.store %3, %arg0[%arg3
-///     floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-///   }
-///   %0 = alloc() : memref<16xf64, #map0>
-///   %1 = affine.load %0[0] : memref<16xf64, #map0>
-///   return %arg0 : memref<4x4xf64>
-/// }
-///
 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void runOnOperation() override;
   void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
@@ -73,6 +45,7 @@
 }
 
 void NormalizeMemRefs::runOnOperation() {
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
   ModuleOp moduleOp = getOperation();
   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
   // with all the functions within a module and then functions which are not
@@ -92,6 +65,7 @@
   moduleOp.walk([&](FuncOp funcOp) {
     if (normalizableFuncs.contains(funcOp)) {
       if (!areMemRefsNormalizable(funcOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "@" << funcOp.getName() << " contains ops that cannot normalize MemRefs\n");
        // Since this function is not normalizable, we set all the caller
        // functions and the callees of this function as not normalizable.
        // TODO: Drop this conservative assumption in the future.
@@ -101,6 +75,7 @@
     }
   });
 
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() << " functions\n");
   // Those functions which can be normalized are subjected to normalization.
   for (FuncOp &funcOp : normalizableFuncs)
     normalizeFuncOpMemRefs(funcOp, moduleOp);
@@ -127,6 +102,7 @@
   if (!normalizableFuncs.contains(funcOp))
     return;
 
+  LLVM_DEBUG(llvm::dbgs() << "@" << funcOp.getName() << " calls or is called by non-normalizable function\n");
   normalizableFuncs.erase(funcOp);
   // Caller of the function.
   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
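The `LLVM_DEBUG` statements added above use LLVM's standard debug facility: they emit output only in builds with assertions enabled, and they are gated on the file's `DEBUG_TYPE`. Assuming `NormalizeMemRefs.cpp` defines (or gains) a `DEBUG_TYPE` along the lines below, the new tracing can be selected with `-debug-only`; the tag shown is an assumption, not taken from the patch.

```c++
#include "llvm/Support/Debug.h"

// Assumed tag; the actual value is whatever NormalizeMemRefs.cpp defines.
#define DEBUG_TYPE "normalize-memrefs"

// Example: this prints only when the tool is run with -debug, or with
// -debug-only=normalize-memrefs, e.g.
//   mlir-opt -normalize-memrefs -debug-only=normalize-memrefs input.mlir
static void emitTraceExample() {
  LLVM_DEBUG(llvm::dbgs() << "normalize-memrefs: tracing enabled\n");
}
```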