diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -24,7 +24,9 @@
 class FuncOp;
 class ModuleOp;
 class Pass;
-template <typename OpT> class OperationPass;
+
+template <typename OpT>
+class OperationPass;
 
 /// Creates an instance of the BufferPlacement pass.
 std::unique_ptr<Pass> createBufferPlacementPass();
@@ -89,6 +91,10 @@
 /// Creates a pass which delete symbol operations that are unreachable. This
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
+
+/// Creates an interprocedural pass to normalize memrefs to have a trivial
+/// (identity) layout map.
+std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -406,4 +406,8 @@
   let constructor = "mlir::createSymbolDCEPass()";
 }
 
+def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
+  let summary = "Interprocedural normalization of memrefs";
+  let constructor = "mlir::createNormalizeMemRefsPass()";
+}
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -45,10 +45,15 @@
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
+/// `allowNonDereferencingOps`, if set, allows the replacement of
+/// non-dereferencing uses of the memref without requiring any access index
+/// rewrites; it defaults to false.
+///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-dereferencing context,
-/// except for dealloc's on the memref which are left untouched. See comments at
-/// function definition for an example.
+/// whenever a memref is used as an operand in a non-dereferencing context and
+/// `allowNonDereferencingOps` is false, except for dealloc's on the memref,
+/// which are left untouched. See the comments at the function definition for
+/// an example.
 //
 //  Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
 //  The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
@@ -63,22 +68,34 @@
                        ArrayRef<Value> extraOperands = {},
                        ArrayRef<Value> symbolOperands = {},
                        Operation *domInstFilter = nullptr,
-                       Operation *postDomInstFilter = nullptr);
+                       Operation *postDomInstFilter = nullptr,
+                       bool allowNonDereferencingOps = false);
 
 /// Performs the same replacement as the other version above but only for the
-/// dereferencing uses of `oldMemRef` in `op`.
+/// dereferencing uses of `oldMemRef` in `op`. If `allowNonDereferencingOps`
+/// is set to true, non-dereferencing uses of `oldMemRef` in `op` are replaced
+/// as well.
 LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        Operation *op,
                                        ArrayRef<Value> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
                                        ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {});
+                                       ArrayRef<Value> symbolOperands = {},
+                                       bool allowNonDereferencingOps = false);
 
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
 LogicalResult normalizeMemRef(AllocOp op);
 
+/// Normalizes `memrefType` by computing a new memref type whose shape is
+/// derived from the old layout map and whose layout map is the identity, i.e.
+/// the old layout map is folded into the shape. Returns the old memref type
+/// unchanged if no normalization was needed or if the old layout map could
+/// not be transformed into an identity map.
+MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder,
+                               unsigned numSymbolicOperands);
+
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
 /// of all other AffineApplyOps reachable from input parameter 'operands'. If
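To make the new flag concrete: with allowNonDereferencingOps set, dereferencing uses of the memref still have their access indices rewritten through indexRemap, while non-dereferencing uses such as calls or returns simply have the memref operand (and hence its type) swapped. The following MLIR fragment is illustrative only and not taken from this patch (%Anew, @callee and %i are hypothetical), using #tile = affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)> as the index remap:

  // Before replacing %A : memref<16xf64, #tile> with %Anew : memref<4x4xf64>:
  %v = affine.load %A[%i] : memref<16xf64, #tile>    // dereferencing use
  call @callee(%A) : (memref<16xf64, #tile>) -> ()   // non-dereferencing use
  // After replacement with allowNonDereferencingOps = true:
  %v = affine.load %Anew[%i floordiv 4, %i mod 4] : memref<4x4xf64>
  call @callee(%Anew) : (memref<4x4xf64>) -> ()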
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -96,13 +96,4 @@
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       applyOpPatternsAndFold(op, patterns);
   });
-
-  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
-  SmallVector<AllocOp, 4> allocOps;
-  func.walk([&](AllocOp op) { allocOps.push_back(op); });
-  for (auto allocOp : allocOps) {
-    normalizeMemRef(allocOp);
-  }
 }
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
   MemRefDataFlowOpt.cpp
+  NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
new file
--- /dev/null
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -0,0 +1,236 @@
+//===- NormalizeMemRefs.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an interprocedural pass to normalize memrefs to have
+// identity layout maps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "normalize-memrefs"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts memrefs with non-trivial layout maps, both those allocated within
+/// a function and those passed across function boundaries, into memrefs with
+/// a trivial (identity) layout map.
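+//
+// Within each function, the pass proceeds as follows:
+//   1. Memrefs defined by alloc ops are normalized via normalizeMemRef().
+//   2. Each memref-typed function argument is rewritten to its normalized
+//      type and all of its uses (dereferencing and non-dereferencing) are
+//      replaced.
+//   3. Call sites of the function are rebuilt so that callers see the
+//      normalized result types, and uses of the old call results are
+//      replaced.
+//   4. Finally, the function signature is updated to the new argument and
+//      result types.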
+ +// Input :- +// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)> +// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) -> +// (memref<16xf64, #tile>) { +// affine.for %arg3 = 0 to 16 { +// %a = affine.load %A[%arg3] : memref<16xf64, #tile> +// %p = mulf %a, %a : f64 +// affine.store %p, %A[%arg3] : memref<16xf64, #tile> +// } +// %c = alloc() : memref<16xf64, #tile> +// %d = affine.load %c[0] : memref<16xf64, #tile> +// return %A: memref<16xf64, #tile> +// } + +// Output :- +// module { +// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>) +// -> memref<4x4xf64> { +// affine.for %arg3 = 0 to 16 { +// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64> +// %3 = mulf %2, %2 : f64 +// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64> +// } +// %0 = alloc() : memref<16xf64, #map0> +// %1 = affine.load %0[0] : memref<16xf64, #map0> +// return %arg0 : memref<4x4xf64> +// } +// } + +struct NormalizeMemRefs : public NormalizeMemRefsBase { + void runOnOperation() override; + void runOnFunction(FuncOp, ModuleOp); +}; + +} // end anonymous namespace + +std::unique_ptr> mlir::createNormalizeMemRefsPass() { + return std::make_unique(); +} + +void NormalizeMemRefs::runOnOperation() { + // Here we get hold of the module/operation that basically contains one + // region. + ModuleOp moduleOp = getOperation(); + + // We traverse each function within the module in order to normalize the + // memref type arguments. + // TODO(avarmapml): Handle external functions. + moduleOp.walk([&](FuncOp funcOp) { runOnFunction(funcOp, moduleOp); }); +} + +void NormalizeMemRefs::runOnFunction(FuncOp funcOp, ModuleOp moduleOp) { + // Turn memrefs' non-identity layouts maps into ones with identity. Collect + // alloc ops first and then process since normalizeMemRef replaces/erases ops + // during memref rewriting. + SmallVector allocOps; + funcOp.walk([&](AllocOp op) { allocOps.push_back(op); }); + for (AllocOp allocOp : allocOps) { + normalizeMemRef(allocOp); + } + // We use this OpBuilder to create new memref layout later. + OpBuilder b(funcOp); + + // Getting hold of the signature of the function before normalizing the + // inputs. + FunctionType ft = funcOp.getType(); + SmallVector argTypes; + SmallVector resultTypes; + + // Populating results with function's initial result (type) as + // this will be used to modify and set function's signature later. + for (unsigned retIndex : llvm::seq(0, funcOp.getNumResults())) { + resultTypes.push_back(ft.getResult(retIndex)); + } + + // Walk over each argument of a function to perform memref normalization (if + // any). + for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { + Type argType = funcOp.getArgument(argIndex).getType(); + MemRefType memrefType = argType.dyn_cast(); + // Check whether argument is of MemRef type. + if (!memrefType) { + // Any other argument type can simply be part of the final function + // signature. + argTypes.push_back(argType); + continue; + } + + // Fetch a new memref type after normalizing the old memref to have an + // identity map layout. + MemRefType newMemRefType = normalizeMemRefType(memrefType, b, + /*numSymbolicOperands=*/0); + if (newMemRefType == memrefType) { + // Either memrefType already had an identity map or the map couldn't be + // transformed to an identity map. + argTypes.push_back(argType); + continue; + } + + // Insert a new temporary argument with the new memref type. 
+ funcOp.front().insertArgument(argIndex, newMemRefType); + BlockArgument newMemRef = funcOp.getArgument(argIndex); + BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); + AffineMap layoutMap = memrefType.getAffineMaps().front(); + // Replace all uses of the old memref. + if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, + /*extraIndices=*/{}, + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/{}, + /*domInstFilter=*/nullptr, + /*postDomInstFilter=*/nullptr, + /*allowNonDereferencingOps=*/true))) { + // If it failed (due to escapes for example), bail out. Removing the + // temporary argument inserted previously. + funcOp.front().eraseArgument(argIndex); + continue; + } + + // Since in this pass the objective is to normalize the layout maps of + // the memref arguments and replace the uses accordingly, we + // check if the function return type uses the same old memref type. + // TODO(avarmapml): Check - A function's return type might have a + // different memref layout and a map. + for (unsigned retIndex : llvm::seq(0, funcOp.getNumResults())) { + if (resultTypes[retIndex] == memrefType) + resultTypes[retIndex] = newMemRef.getType(); + } + + // All uses for the argument with old memref type were replaced + // successfully. So we remove the old argument now. + // TODO(avarmapml): replaceAllUsesWith. + funcOp.front().eraseArgument(argIndex + 1); + + // Add the new type to the function signature later. + argTypes.push_back(newMemRefType); + } + + // We create a new function type and modify the function signature with this + // new type. + FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes, + /*results=*/resultTypes, + /**context=*/&getContext()); + + // We get hold of all the symbolic uses of this function within the given + // region. + Optional symbolUses = SymbolTable::getSymbolUses( + funcOp.getAttrOfType(SymbolTable::getSymbolAttrName()) + .getValue(), + &moduleOp.getBodyRegion()); + // We iterate over all symbolic uses of the function and update the return + // type at the caller site. + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + Operation *callOp = symbolUse.getUser(); + // We build a new CallOp to reflect the updated return type of the function. + OpBuilder builder(callOp); + OperationState state(callOp->getLoc(), callOp->getName()); + state.operands.reserve(callOp->getNumOperands()); + state.operands.append(callOp->operand_begin(), callOp->operand_end()); + // We take care only of the return type from the function signature because + // the operand type will get changed later at the caller function. + state.addTypes(resultTypes); + StringRef callee = cast(callOp).getCallee(); + state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + Operation *newCallOp = builder.createOperation(state); + bool replacingMemRefUsesFailed = false; + // A function might return more than one result. So here we loop over + // all the results returned and replace uses of any result of type Memref + // whose map layout has changed. + for (unsigned resIndex : llvm::seq(0, callOp->getNumResults())) { + OpResult oldMemRef = callOp->getResult(resIndex); + OpResult newMemRef = newCallOp->getResult(resIndex); + // If the result type is not a Memref or if there is no change in the + // Memref map layout, there is no need to call replaceAllMemRefUsesWith. 
+      if (oldMemRef.getType() == newMemRef.getType())
+        continue;
+      AffineMap layoutMap =
+          oldMemRef.getType().cast<MemRefType>().getAffineMaps().front();
+      if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
+                                          /*extraIndices=*/{},
+                                          /*indexRemap=*/layoutMap,
+                                          /*extraOperands=*/{},
+                                          /*symbolOperands=*/{},
+                                          /*domInstFilter=*/nullptr,
+                                          /*postDomInstFilter=*/nullptr,
+                                          /*allowNonDereferencingOps=*/true))) {
+        // If it failed (due to escapes for example), bail out.
+        newCallOp->erase();
+        replacingMemRefUsesFailed = true;
+        break;
+      }
+    }
+    if (replacingMemRefUsesFailed)
+      continue;
+    callOp->replaceAllUsesWith(newCallOp);
+    callOp->erase();
+  }
+  funcOp.setType(newFuncType);
+}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -48,7 +48,8 @@
                                              ArrayRef<Value> extraIndices,
                                              AffineMap indexRemap,
                                              ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands) {
+                                             ArrayRef<Value> symbolOperands,
+                                             bool allowNonDereferencingOps) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -67,11 +68,6 @@
   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
          newMemRef.getType().cast<MemRefType>().getElementType());
 
-  if (!isMemRefDereferencingOp(*op))
-    // Failure: memref used in a non-dereferencing context (potentially
-    // escapes); no replacement in these cases.
-    return failure();
-
   SmallVector<unsigned, 4> usePositions;
   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
     if (opEntry.value() == oldMemRef)
@@ -91,6 +87,45 @@
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
+  OperationState state(op->getLoc(), op->getName());
+  // Check whether the op dereferences the memref; if not, only the memref
+  // operand needs to be replaced and no access index rewriting is required.
+  if (!isMemRefDereferencingOp(*op)) {
+    if (!allowNonDereferencingOps)
+      // Failure: memref used in a non-dereferencing context (potentially
+      // escapes); no replacement in these cases unless
+      // allowNonDereferencingOps is set.
+      return failure();
+    // For a non-dereferencing op, we simply swap in the new memref operand.
+    state.operands.reserve(op->getNumOperands() + extraIndices.size());
+    // Insert the operands that appear before the memref operand.
+    state.operands.append(op->operand_begin(),
+                          op->operand_begin() + memRefOperandPos);
+    // Insert the new memref value.
+    state.operands.push_back(newMemRef);
+    // Insert the remaining operands that follow the memref operand.
+    state.operands.append(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_end());
+
+    state.types.reserve(op->getNumResults());
+    for (auto result : op->getResults())
+      state.types.push_back(result.getType());
+
+    // For CallOps, we additionally need to set the 'callee' attribute to the
+    // callee's function name.
+    CallOp callOp = dyn_cast<CallOp>(op);
+    if (callOp) {
+      StringRef callee = callOp.getCallee();
+      state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+    }
+
+    // Create the new operation.
+ Operation *repOp = builder.createOperation(state); + op->replaceAllUsesWith(repOp); + op->erase(); + + return success(); + } + // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); @@ -112,7 +147,7 @@ affineApplyOps.push_back(afOp); } } else { - oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); + oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); } // Construct new indices as a remap of the old ones if a remapping has been @@ -141,14 +176,14 @@ } } else { // No remapping specified. - remapOutputs.append(remapOperands.begin(), remapOperands.end()); + remapOutputs.assign(remapOperands.begin(), remapOperands.end()); } SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. - for (auto extraIndex : extraIndices) { + for (Value extraIndex : extraIndices) { assert(extraIndex.getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && @@ -167,12 +202,11 @@ newMap = simplifyAffineMap(newMap); canonicalizeMapAndOperands(&newMap, &newMapOperands); // Remove any affine.apply's that became dead as a result of composition. - for (auto value : affineApplyOps) + for (Value value : affineApplyOps) if (value.use_empty()) value.getDefiningOp()->erase(); // Construct the new operation using this memref. - OperationState state(op->getLoc(), op->getName()); state.operands.reserve(op->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. state.operands.append(op->operand_begin(), @@ -196,11 +230,10 @@ // Add attribute for 'newMap', other Attributes do not change. auto newMapAttr = AffineMapAttr::get(newMap); for (auto namedAttr : op->getAttrs()) { - if (namedAttr.first == oldMapAttrPair.first) { + if (namedAttr.first == oldMapAttrPair.first) state.attributes.push_back({namedAttr.first, newMapAttr}); - } else { + else state.attributes.push_back(namedAttr); - } } // Create the new operation. @@ -211,13 +244,11 @@ return success(); } -LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, - ArrayRef extraIndices, - AffineMap indexRemap, - ArrayRef extraOperands, - ArrayRef symbolOperands, - Operation *domInstFilter, - Operation *postDomInstFilter) { +LogicalResult mlir::replaceAllMemRefUsesWith( + Value oldMemRef, Value newMemRef, ArrayRef extraIndices, + AffineMap indexRemap, ArrayRef extraOperands, + ArrayRef symbolOperands, Operation *domInstFilter, + Operation *postDomInstFilter, bool allowNonDereferencingOps) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); @@ -267,10 +298,15 @@ // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. - if (!isMemRefDereferencingOp(*op)) - // Failure: memref used in a non-dereferencing op (potentially escapes); - // no replacement in these cases. - return failure(); + if (!isMemRefDereferencingOp(*op)) { + // Currently we support the following non-dereferencing types to be a + // candidate for replacement : DeallocOp, CallOp and ReturnOp. 
+    // TODO: Add support for other kinds of ops.
+    if (!allowNonDereferencingOps)
+      return failure();
+    if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
+      return failure();
+  }
 
     // We'll first collect and then replace --- since replacement erases the op
     // that has the use, and that op could be postDomFilter or domFilter itself!
@@ -278,9 +314,9 @@
   }
 
   for (auto *op : opsToReplace) {
-    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
-                                        indexRemap, extraOperands,
-                                        symbolOperands)))
+    if (failed(replaceAllMemRefUsesWith(
+            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
+            symbolOperands, allowNonDereferencingOps)))
       llvm_unreachable("memref replacement guaranteed to succeed here");
   }
 
@@ -385,83 +421,110 @@
 // TODO: Currently works for static memrefs with a single layout map.
 LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   MemRefType memrefType = allocOp.getType();
-  unsigned rank = memrefType.getRank();
-  if (rank == 0)
-    return success();
-
-  auto layoutMaps = memrefType.getAffineMaps();
   OpBuilder b(allocOp);
-  if (layoutMaps.size() != 1)
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType =
+      normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands());
+  if (newMemRefType == memrefType)
+    // Either memrefType already had an identity map or the map couldn't be
+    // transformed to an identity map.
     return failure();
 
-  AffineMap layoutMap = layoutMaps.front();
+  Value oldMemRef = allocOp.getResult();
+
+  // Check whether all uses of oldMemRef are either dereferencing uses or
+  // non-dereferencing uses of type DeallocOp, CallOp or ReturnOp. Only if
+  // these constraints are satisfied is the alloc op a candidate for
+  // replacement.
+  if (llvm::any_of(oldMemRef.getUsers(), [](Operation *op) {
+        if (isMemRefDereferencingOp(*op))
+          return false;
+        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
+      }))
+    return failure();
 
-  // Nothing to do for identity layout maps.
-  if (layoutMap == b.getMultiDimIdentityMap(rank))
-    return success();
+  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
+  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/layoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/symbolOperands,
+                                      /*domInstFilter=*/nullptr,
+                                      /*postDomInstFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newAlloc.erase();
+    return failure();
+  }
+  // Replace any uses of the original alloc op and erase it. All remaining uses
+  // have to be dealloc's; RAMUW above would've failed otherwise.
+  assert(llvm::all_of(oldMemRef.getUsers(),
+                      [](Operation *op) { return isa<DeallocOp>(op); }));
+  oldMemRef.replaceAllUsesWith(newAlloc);
+  allocOp.erase();
+  return success();
+}
+
+MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
+                                     unsigned numSymbolicOperands) {
+  unsigned rank = memrefType.getRank();
+  if (rank == 0)
+    return memrefType;
+
+  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
+  if (layoutMaps.empty() ||
+      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+    // Either no map is associated with this memref or the memref already has
+    // a trivial (identity) map.
+    return memrefType;
+  }
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
 
   // TODO: Only for static memref's for now.
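+  // As a worked example: for memref<16xf64, (d0) -> (d0 floordiv 4, d0 mod 4)>,
+  // composing the layout map with the constraint 0 <= d0 <= 15 yields
+  // 0 <= d0' <= 3 and 0 <= d1' <= 3, so the computation below produces
+  // memref<4x4xf64> with an identity layout map.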
   if (memrefType.getNumDynamicDims() > 0)
-    return failure();
+    return memrefType;
 
-  // We have a single map that is not an identity map. Create a new memref with
-  // the right shape and an identity layout map.
-  auto shape = memrefType.getShape();
-  FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
+  // We have a single map that is not an identity map. Create a new memref
+  // with the right shape and an identity layout map.
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  // FlatAffineConstraints may later make use of the symbolic operands.
+  FlatAffineConstraints fac(rank, numSymbolicOperands);
   for (unsigned d = 0; d < rank; ++d) {
     fac.addConstantLowerBound(d, 0);
     fac.addConstantUpperBound(d, shape[d] - 1);
   }
-
-  // We compose this map with the original index (logical) space to derive the
-  // upper bounds for the new index space.
+  // We compose this map with the original index (logical) space to derive
+  // the upper bounds for the new index space.
+  AffineMap layoutMap = layoutMaps.front();
   unsigned newRank = layoutMap.getNumResults();
+  // TODO: Handle semi-affine maps.
   if (failed(fac.composeMatchingMap(layoutMap)))
-    // TODO: semi-affine maps.
-    return failure();
-
+    return memrefType;
   // Project out the old data dimensions.
   fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
   SmallVector<int64_t, 4> newShape(newRank);
   for (unsigned d = 0; d < newRank; ++d) {
     // The lower bound for the shape is always zero.
     auto ubConst = fac.getConstantUpperBound(d);
-    // For a static memref and an affine map with no symbols, this is always
-    // bounded.
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded.
     assert(ubConst.hasValue() && "should always have an upper bound");
     if (ubConst.getValue() < 0)
       // This is due to an invalid map that maps to a negative space.
-      return failure();
+      return memrefType;
     newShape[d] = ubConst.getValue() + 1;
   }
 
-  auto oldMemRef = allocOp.getResult();
-  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
-
+  // Create the new memref type after trivializing the old layout map.
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setAffineMaps(b.getMultiDimIdentityMap(newRank));
-  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
 
-  // Replace all uses of the old memref.
-  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
-                                      /*extraIndices=*/{},
-                                      /*indexRemap=*/layoutMap,
-                                      /*extraOperands=*/{},
-                                      /*symbolOperands=*/symbolOperands))) {
-    // If it failed (due to escapes for example), bail out.
-    newAlloc.erase();
-    return failure();
-  }
-  // Replace any uses of the original alloc op and erase it. All remaining uses
-  // have to be dealloc's; RAMUW above would've failed otherwise.
- assert(llvm::all_of(oldMemRef.getUsers(), - [](Operation *op) { return isa(op); })); - oldMemRef.replaceAllUsesWith(newAlloc); - allocOp.erase(); - return success(); + return newMemRefType; } diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/normalize-memrefs.mlir rename from mlir/test/Transforms/memref-normalize.mlir rename to mlir/test/Transforms/normalize-memrefs.mlir --- a/mlir/test/Transforms/memref-normalize.mlir +++ b/mlir/test/Transforms/normalize-memrefs.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s +// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s + +// This file tests whether the memref type having non-trivial map layouts +// are normalized to trivial (identity) layouts. // CHECK-LABEL: func @permute() func @permute() { @@ -143,3 +146,66 @@ } return } + +#tile = affine_map<(i) -> (i floordiv 4, i mod 4)> + +// Following test cases check the inter-procedural memref normalization. + +// Test case 1: Checks whether memref map layout normalization takes place successfully for self-recursive function calls. +// Also checks cases where function call might return a non-memref type. +// TODO(avarmapml): Add thorough testing of inter-recursive calls. +// CHECK-LABEL: func @recursive +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>) -> memref<4x4xf64> +func @recursive(%A: memref<16xf64, #tile>) -> (memref<16xf64, #tile>) { + %M = call @recursive(%A) : (memref<16xf64, #tile>) -> (memref<16xf64, #tile>) + %c = alloc() : memref<8xf64> + %b = constant 4.5 : f64 + %R, %cond = call @argument_type(%A, %b, %c) : (memref<16xf64, #tile>, f64, memref<8xf64>) -> (memref<16xf64, #tile>, i1) + cond_br %cond, ^bb1, ^bb2 + ^bb1: + return %M: memref<16xf64, #tile> + ^bb2: + return %R: memref<16xf64, #tile> +} + +// CHECK: [[M:%[0-9]+]] = call @recursive(%[[A]]) : (memref<4x4xf64>) -> memref<4x4xf64> +// CHECK: [[R:%[0-9]]]:2 = call @argument_type(%[[A]], %cst, %[[c:[0-9]+]]) : (memref<4x4xf64>, f64, memref<8xf64>) -> (memref<4x4xf64>, i1) +// CHECK: cond_br [[R]]#1, ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: return [[M]] : memref<4x4xf64> +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: return [[R]]#0 : memref<4x4xf64> + + +// Test case 2: Checks whether function arguments of memref type having non-trivial map layout get normalized. 
+// Also checks that the normalized layouts propagate interprocedurally.
+// CHECK-LABEL: func @argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<8xf64>) -> (memref<4x4xf64>, i1)
+func @argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64>) -> (memref<16xf64, #tile>, i1) {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  affine.store %p, %A[10] : memref<16xf64, #tile>
+  %M = call @recursive(%A) : (memref<16xf64, #tile>) -> (memref<16xf64, #tile>)
+  %m = affine.load %M[1] : memref<16xf64, #tile>
+  %cond = constant 1 : i1
+  return %A, %cond : memref<16xf64, #tile>, i1
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
+// CHECK: [[MEM:%[0-9]+]] = call @recursive(%[[A]]) : (memref<4x4xf64>) -> memref<4x4xf64>
+// CHECK: %[[b:[0-9]+]] = affine.load [[MEM]][0, 1] : memref<4x4xf64>
+// CHECK: %true = constant true
+// CHECK: return %[[A]], %true : memref<4x4xf64>, i1
+
+// Test case 3: Check normalization for multiple memrefs in a function's argument list.
+// CHECK-LABEL: func @multiple_arg_normalize
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: memref<2x4xf64>) -> memref<2x4xf64>
+func @multiple_arg_normalize(%A: memref<16xf64, #tile>, %B: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  return %B : memref<8xf64, #tile>
+}
+
+// CHECK: %[[argA:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: return %[[B]] : memref<2x4xf64>
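A negative case not yet exercised by these tests: normalizeMemRefType() leaves memrefs with dynamic dimensions untouched (it returns the type unchanged when getNumDynamicDims() > 0), so a #tile-mapped dynamic memref argument keeps its non-trivial layout. A hypothetical test along the following lines could pin that behavior down; the function name and CHECK patterns are illustrative only and are not part of this patch:

// Dynamic shapes are not normalized yet; the non-trivial layout is kept.
// CHECK-LABEL: func @dynamic_not_normalized
// CHECK-SAME: (%{{.*}}: memref<?xf64, #map{{[0-9]*}}>)
func @dynamic_not_normalized(%A: memref<?xf64, #tile>) {
  %a = affine.load %A[0] : memref<?xf64, #tile>
  return
}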