diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -24,7 +24,9 @@ class FuncOp; class ModuleOp; class Pass; -template class OperationPass; + +template +class OperationPass; /// Creates an instance of the BufferPlacement pass. std::unique_ptr createBufferPlacementPass(); @@ -89,6 +91,10 @@ /// Creates a pass which delete symbol operations that are unreachable. This /// pass may *only* be scheduled on an operation that defines a SymbolTable. std::unique_ptr createSymbolDCEPass(); + +/// Creates an interprocedural pass to normalize memrefs to have a trivial +/// (identity) layout map. +std::unique_ptr> createNormalizeMemRefsPass(); } // end namespace mlir #endif // MLIR_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -309,6 +309,11 @@ let constructor = "mlir::createMemRefDataFlowOptPass()"; } +def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> { + let summary = "Normalize memrefs"; + let constructor = "mlir::createNormalizeMemRefsPass()"; +} + def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> { let summary = "Collapse parallel loops to use less induction variables"; let constructor = "mlir::createParallelLoopCollapsingPass()"; @@ -405,5 +410,4 @@ }]; let constructor = "mlir::createSymbolDCEPass()"; } - #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -45,10 +45,19 @@ /// operations that are dominated by the former; similarly, `postDomInstFilter` /// restricts replacement to only those operations that are postdominated by it. /// +/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing +/// uses of a memref without any requirement for access index rewrites. The +/// default value of this flag variable is false. +/// +/// 'replaceInDeallocOp', if set, lets DeallocOp, a non-dereferencing user, to +/// also be a candidate for replacement. The default value of this flag is +/// false. +/// /// Returns true on success and false if the replacement is not possible, -/// whenever a memref is used as an operand in a non-dereferencing context, -/// except for dealloc's on the memref which are left untouched. See comments at -/// function definition for an example. +/// whenever a memref is used as an operand in a non-dereferencing context and +/// 'allowNonDereferencingOps' is false, except for dealloc's on the memref +/// which are left untouched. See comments at function definition for an +/// example. // // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]: // The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and @@ -57,28 +66,38 @@ // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). // TODO: allow extraIndices to be added at any position. 
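For illustration, here is a minimal sketch (not part of the patch) of how a caller might use the two new flags; `oldMemRef`, `newMemRef` and `layoutMap` are assumed to be in scope, with `layoutMap` being the non-identity layout map of the old memref type:

    // Rewrite every use of oldMemRef: dereferencing uses get their access
    // indices remapped through layoutMap; call and dealloc uses (allowed by
    // the two new flags) simply have the memref operand swapped to newMemRef.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/layoutMap,
                                        /*extraOperands=*/{},
                                        /*symbolOperands=*/{},
                                        /*domInstFilter=*/nullptr,
                                        /*postDomInstFilter=*/nullptr,
                                        /*allowNonDereferencingOps=*/true,
                                        /*replaceInDeallocOp=*/true)))
      return failure();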
-LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, - ArrayRef extraIndices = {}, - AffineMap indexRemap = AffineMap(), - ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}, - Operation *domInstFilter = nullptr, - Operation *postDomInstFilter = nullptr); +LogicalResult replaceAllMemRefUsesWith( + Value oldMemRef, Value newMemRef, ArrayRef extraIndices = {}, + AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, + ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, + Operation *postDomInstFilter = nullptr, + bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false); /// Performs the same replacement as the other version above but only for the -/// dereferencing uses of `oldMemRef` in `op`. +/// dereferencing uses of `oldMemRef` in `op`; if 'allowNonDereferencingOps' is +/// set to true, replacement is performed for the non-dereferencing uses as +/// well. LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, Operation *op, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}); + ArrayRef symbolOperands = {}, + bool allowNonDereferencingOps = false); /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. Returns failure if any of its uses /// escape (while leaving the IR in a valid state). LogicalResult normalizeMemRef(AllocOp op); +/// Normalizes `memrefType` so that its affine layout map is transformed into +/// an identity layout map, computing a new shape for the normalized memref +/// type. Returns the old memref type unchanged if no normalization was needed +/// or if the old layout map could not be transformed into an identity map. +MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder, + unsigned numSymbolicOperands); + /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// its results equal to the number of operands, as a composition /// of all other AffineApplyOps reachable from input parameter 'operands'. If diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -96,13 +96,4 @@ if (isa(op)) applyOpPatternsAndFold(op, patterns); }); - - // Turn memrefs' non-identity layouts maps into ones with identity. Collect - // alloc ops first and then process since normalizeMemRef replaces/erases ops - // during memref rewriting.
- SmallVector allocOps; - func.walk([&](AllocOp op) { allocOps.push_back(op); }); - for (auto allocOp : allocOps) { - normalizeMemRef(allocOp); - } } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ LoopFusion.cpp LoopInvariantCodeMotion.cpp MemRefDataFlowOpt.cpp + NormalizeMemRefs.cpp OpStats.cpp ParallelLoopCollapsing.cpp PipelineDataTransfer.cpp diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -0,0 +1,220 @@ +//===- NormalizeMemRefs.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an interprocedural pass to normalize memrefs to have +// identity layout maps. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" + +#define DEBUG_TYPE "normalize-memrefs" + +using namespace mlir; + +namespace { + +/// All interprocedural memrefs with non-trivial layout maps are converted to +/// ones with trivial identity layout ones. + +// Input :- +// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)> +// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) -> +// (memref<16xf64, #tile>) { +// affine.for %arg3 = 0 to 16 { +// %a = affine.load %A[%arg3] : memref<16xf64, #tile> +// %p = mulf %a, %a : f64 +// affine.store %p, %A[%arg3] : memref<16xf64, #tile> +// } +// %c = alloc() : memref<16xf64, #tile> +// %d = affine.load %c[0] : memref<16xf64, #tile> +// return %A: memref<16xf64, #tile> +// } + +// Output :- +// module { +// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>) +// -> memref<4x4xf64> { +// affine.for %arg3 = 0 to 16 { +// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64> +// %3 = mulf %2, %2 : f64 +// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64> +// } +// %0 = alloc() : memref<16xf64, #map0> +// %1 = affine.load %0[0] : memref<16xf64, #map0> +// return %arg0 : memref<4x4xf64> +// } +// } + +struct NormalizeMemRefs : public NormalizeMemRefsBase { + void runOnOperation() override; + void runOnFunction(FuncOp funcOp); + bool areMemRefsNormalizable(FuncOp funcOp); + void updateFunctionSignature(FuncOp funcOp); +}; + +} // end anonymous namespace + +std::unique_ptr> mlir::createNormalizeMemRefsPass() { + return std::make_unique(); +} + +void NormalizeMemRefs::runOnOperation() { + ModuleOp moduleOp = getOperation(); + // We traverse each function within the module in order to normalize the + // memref type arguments. + // TODO: Handle external functions. + moduleOp.walk([&](FuncOp funcOp) { + if (areMemRefsNormalizable(funcOp)) + runOnFunction(funcOp); + }); +} + +// Return true if this operation dereferences one or more memref's. +// TODO: Temporary utility, will be replaced when this is modeled through +// side-effects/op traits. 
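As an aside, a hedged sketch (not part of the patch, assuming `using namespace mlir`; the helper name `normalizeModuleMemRefs` is made up for illustration) of scheduling the new module-level pass programmatically, equivalent to `mlir-opt -normalize-memrefs`:

    #include "mlir/IR/Module.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    static LogicalResult normalizeModuleMemRefs(ModuleOp module,
                                                MLIRContext *ctx) {
      PassManager pm(ctx);
      // The pass is registered on ModuleOp (see Passes.td above), so it is
      // added at the top level rather than nested under FuncOp.
      pm.addPass(createNormalizeMemRefsPass());
      return pm.run(module);
    }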
+static bool isMemRefDereferencingOp(Operation &op) { + return isa(op); +} + +// Check whether all the uses of oldMemRef are either dereferencing uses or the +// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied +// will the value become a candidate for replacement. +static bool isMemRefNormalizable(Value::user_range opUsers) { + if (llvm::any_of(opUsers, [](Operation *op) { + if (isMemRefDereferencingOp(*op)) + return false; + return !isa(*op); + })) + return false; + return true; +} + +// Check whether all the uses of AllocOps, CallOps and function arguments of a +// function are either of dereferencing type or of type: DeallocOp, CallOp. Only +// if these constraints are satisfied will the function become a candidate for +// normalization. +bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) { + if (funcOp + .walk([&](AllocOp allocOp) -> WalkResult { + Value oldMemRef = allocOp.getResult(); + if (!isMemRefNormalizable(oldMemRef.getUsers())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + + if (funcOp + .walk([&](CallOp callOp) { + for (unsigned resIndex : + llvm::seq(0, callOp.getNumResults())) { + Value oldMemRef = callOp.getResult(resIndex); + if (oldMemRef.getType().isa()) + if (!isMemRefNormalizable(oldMemRef.getUsers())) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + + for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { + BlockArgument oldMemRef = funcOp.getArgument(argIndex); + if (oldMemRef.getType().isa()) + if (!isMemRefNormalizable(oldMemRef.getUsers())) + return false; + } + + return true; +} + +// Fetch the updated argument list and result of the function and update the +// function signature. +void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) { + FunctionType functionType = funcOp.getType(); + SmallVector argTypes; + SmallVector resultTypes; + + for (const auto &arg : llvm::enumerate(funcOp.getArguments())) + argTypes.push_back(arg.value().getType()); + + resultTypes = llvm::to_vector<4>(functionType.getResults()); + // We create a new function type and modify the function signature with this + // new type. + FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes, + /*results=*/resultTypes, + /*context=*/&getContext()); + + // TODO: Handle ReturnOps to update function results the caller site. + funcOp.setType(newFuncType); +} + +void NormalizeMemRefs::runOnFunction(FuncOp funcOp) { + // Turn memrefs' non-identity layouts maps into ones with identity. Collect + // alloc ops first and then process since normalizeMemRef replaces/erases ops + // during memref rewriting. + SmallVector allocOps; + funcOp.walk([&](AllocOp op) { allocOps.push_back(op); }); + for (AllocOp allocOp : allocOps) + normalizeMemRef(allocOp); + + // We use this OpBuilder to create new memref layout later. + OpBuilder b(funcOp); + + // Walk over each argument of a function to perform memref normalization (if + // any). + for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { + Type argType = funcOp.getArgument(argIndex).getType(); + MemRefType memrefType = argType.dyn_cast(); + // Check whether argument is of MemRef type. Any other argument type can + // simply be part of the final function signature. + if (!memrefType) + continue; + // Fetch a new memref type after normalizing the old memref to have an + // identity map layout. 
+ MemRefType newMemRefType = normalizeMemRefType(memrefType, b, + /*numSymbolicOperands=*/0); + if (newMemRefType == memrefType) { + // Either memrefType already had an identity map or the map couldn't be + // transformed to an identity map. + continue; + } + + // Insert a new temporary argument with the new memref type. + BlockArgument newMemRef = + funcOp.front().insertArgument(argIndex, newMemRefType); + BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); + AffineMap layoutMap = memrefType.getAffineMaps().front(); + // Replace all uses of the old memref. + if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, + /*extraIndices=*/{}, + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/{}, + /*domInstFilter=*/nullptr, + /*postDomInstFilter=*/nullptr, + /*allowNonDereferencingOps=*/true, + /*replaceInDeallocOp=*/true))) { + // If it failed (due to escapes for example), bail out, removing the + // temporary argument inserted previously. + funcOp.front().eraseArgument(argIndex); + continue; + } + + // All uses of the argument with the old memref type were replaced + // successfully, so we can now remove the old argument. + funcOp.front().eraseArgument(argIndex + 1); + } + + updateFunctionSignature(funcOp); +} diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -48,7 +48,8 @@ ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - ArrayRef symbolOperands) { + ArrayRef symbolOperands, + bool allowNonDereferencingOps) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); @@ -67,11 +68,6 @@ assert(oldMemRef.getType().cast().getElementType() == newMemRef.getType().cast().getElementType()); - if (!isMemRefDereferencingOp(*op)) - // Failure: memref used in a non-dereferencing context (potentially - // escapes); no replacement in these cases. - return failure(); - SmallVector usePositions; for (const auto &opEntry : llvm::enumerate(op->getOperands())) { if (opEntry.value() == oldMemRef) @@ -91,6 +87,18 @@ unsigned memRefOperandPos = usePositions.front(); OpBuilder builder(op); + // If the op is not a dereferencing use of the memref, there are no access + // indices to rewrite; the memref operand is simply replaced when that is + // allowed. + if (!isMemRefDereferencingOp(*op)) { + if (!allowNonDereferencingOps) + // Failure: memref used in a non-dereferencing context (potentially + // escapes); no replacement in these cases unless allowNonDereferencingOps + // is set. + return failure(); + op->setOperand(memRefOperandPos, newMemRef); + return success(); + } + // Perform index rewrites for the dereferencing op and then replace the op. NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); @@ -112,7 +120,7 @@ affineApplyOps.push_back(afOp); } } else { - oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); + oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); } // Construct new indices as a remap of the old ones if a remapping has been @@ -141,14 +149,14 @@ } } else { // No remapping specified.
- remapOutputs.append(remapOperands.begin(), remapOperands.end()); + remapOutputs.assign(remapOperands.begin(), remapOperands.end()); } SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. - for (auto extraIndex : extraIndices) { + for (Value extraIndex : extraIndices) { assert(extraIndex.getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && @@ -167,12 +175,12 @@ newMap = simplifyAffineMap(newMap); canonicalizeMapAndOperands(&newMap, &newMapOperands); // Remove any affine.apply's that became dead as a result of composition. - for (auto value : affineApplyOps) + for (Value value : affineApplyOps) if (value.use_empty()) value.getDefiningOp()->erase(); - // Construct the new operation using this memref. OperationState state(op->getLoc(), op->getName()); + // Construct the new operation using this memref. state.operands.reserve(op->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. state.operands.append(op->operand_begin(), @@ -196,11 +204,10 @@ // Add attribute for 'newMap', other Attributes do not change. auto newMapAttr = AffineMapAttr::get(newMap); for (auto namedAttr : op->getAttrs()) { - if (namedAttr.first == oldMapAttrPair.first) { + if (namedAttr.first == oldMapAttrPair.first) state.attributes.push_back({namedAttr.first, newMapAttr}); - } else { + else state.attributes.push_back(namedAttr); - } } // Create the new operation. @@ -211,13 +218,12 @@ return success(); } -LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, - ArrayRef extraIndices, - AffineMap indexRemap, - ArrayRef extraOperands, - ArrayRef symbolOperands, - Operation *domInstFilter, - Operation *postDomInstFilter) { +LogicalResult mlir::replaceAllMemRefUsesWith( + Value oldMemRef, Value newMemRef, ArrayRef extraIndices, + AffineMap indexRemap, ArrayRef extraOperands, + ArrayRef symbolOperands, Operation *domInstFilter, + Operation *postDomInstFilter, bool allowNonDereferencingOps, + bool replaceInDeallocOp) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); @@ -261,16 +267,21 @@ // Skip dealloc's - no replacement is necessary, and a memref replacement // at other uses doesn't hurt these dealloc's. - if (isa(op)) + if (isa(op) && !replaceInDeallocOp) continue; // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. - if (!isMemRefDereferencingOp(*op)) - // Failure: memref used in a non-dereferencing op (potentially escapes); - // no replacement in these cases. - return failure(); + if (!isMemRefDereferencingOp(*op)) { + // Currently we support the following non-dereferencing types to be a + // candidate for replacement: Dealloc and CallOp. + // TODO: Add support for other kinds of ops. + if (!allowNonDereferencingOps) + return failure(); + if (!(isa(*op))) + return failure(); + } // We'll first collect and then replace --- since replacement erases the op // that has the use, and that op could be postDomFilter or domFilter itself! 
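For intuition about `indexRemap` in the rewriting above: the normalization code passes the old layout map itself as the remap. A hedged sketch of building the tiled map used throughout this patch (assuming the `AffineMap::get` overload that takes a context, and `using namespace mlir`):

    MLIRContext ctx;
    AffineExpr d0 = getAffineDimExpr(/*position=*/0, &ctx);
    // (d0) -> (d0 floordiv 4, d0 mod 4), i.e. the #tile map from the tests.
    AffineMap tileMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                                       {d0.floorDiv(4), d0 % 4}, &ctx);

With `indexRemap = tileMap`, an access such as `affine.load %A[%i]` is rewritten to `affine.load %Anew[%i floordiv 4, %i mod 4]` after the composition and canonicalization steps above.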
@@ -278,9 +289,9 @@ } for (auto *op : opsToReplace) { - if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices, - indexRemap, extraOperands, - symbolOperands))) + if (failed(replaceAllMemRefUsesWith( + oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands, + symbolOperands, allowNonDereferencingOps))) llvm_unreachable("memref replacement guaranteed to succeed here"); } @@ -385,83 +396,101 @@ // TODO: Currently works for static memrefs with a single layout map. LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { MemRefType memrefType = allocOp.getType(); - unsigned rank = memrefType.getRank(); - if (rank == 0) - return success(); - - auto layoutMaps = memrefType.getAffineMaps(); OpBuilder b(allocOp); - if (layoutMaps.size() != 1) + + // Fetch a new memref type after normalizing the old memref to have an + // identity map layout. + MemRefType newMemRefType = + normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands()); + if (newMemRefType == memrefType) + // Either memrefType already had an identity map or the map couldn't be + // transformed to an identity map. return failure(); - AffineMap layoutMap = layoutMaps.front(); + Value oldMemRef = allocOp.getResult(); - // Nothing to do for identity layout maps. - if (layoutMap == b.getMultiDimIdentityMap(rank)) - return success(); + SmallVector symbolOperands(allocOp.getSymbolicOperands()); + AllocOp newAlloc = b.create(allocOp.getLoc(), newMemRefType); + AffineMap layoutMap = memrefType.getAffineMaps().front(); + // Replace all uses of the old memref. + if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, + /*extraIndices=*/{}, + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/symbolOperands, + /*domInstFilter=*/nullptr, + /*postDomInstFilter=*/nullptr, + /*allowNonDereferencingOps=*/true))) { + // If it failed (due to escapes for example), bail out. + newAlloc.erase(); + return failure(); + } + // Replace any uses of the original alloc op and erase it. All remaining uses + // have to be dealloc's; RAMUW above would've failed otherwise. + assert(llvm::all_of(oldMemRef.getUsers(), + [](Operation *op) { return isa(op); })); + oldMemRef.replaceAllUsesWith(newAlloc); + allocOp.erase(); + return success(); +} + +MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, + unsigned numSymbolicOperands) { + unsigned rank = memrefType.getRank(); + if (rank == 0) + return memrefType; + + ArrayRef layoutMaps = memrefType.getAffineMaps(); + if (layoutMaps.empty() || + layoutMaps.front() == b.getMultiDimIdentityMap(rank)) { + // Either no map is associated with this memref or the memref has + // a trivial (identity) map. + return memrefType; + } // We don't do any checks for one-to-one'ness; we assume that it is // one-to-one. // TODO: Only for static memref's for now. if (memrefType.getNumDynamicDims() > 0) - return failure(); + return memrefType; - // We have a single map that is not an identity map. Create a new memref with - // the right shape and an identity layout map. - auto shape = memrefType.getShape(); - FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands()); + // We have a single map that is not an identity map. Create a new memref + // with the right shape and an identity layout map. + ArrayRef shape = memrefType.getShape(); + // FlatAffineConstraints may later on use symbolicOperands.
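To make the shape computation below concrete, a worked instance for memref<16xf64, #tile> with #tile = (d0) -> (d0 floordiv 4, d0 mod 4): the original index space contributes 0 <= d0 <= 15; composing the layout map adds the result dimensions r0 = d0 floordiv 4 and r1 = d0 mod 4, and after projecting out d0 their constant upper bounds are 3 and 3. The new shape is therefore [3 + 1, 3 + 1] = [4, 4], i.e. the memref<4x4xf64> shown in the example at the top of NormalizeMemRefs.cpp.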
+ FlatAffineConstraints fac(rank, numSymbolicOperands); for (unsigned d = 0; d < rank; ++d) { fac.addConstantLowerBound(d, 0); fac.addConstantUpperBound(d, shape[d] - 1); } - - // We compose this map with the original index (logical) space to derive the - // upper bounds for the new index space. + // We compose this map with the original index (logical) space to derive + // the upper bounds for the new index space. + AffineMap layoutMap = layoutMaps.front(); unsigned newRank = layoutMap.getNumResults(); if (failed(fac.composeMatchingMap(layoutMap))) - // TODO: semi-affine maps. - return failure(); - + // TODO: Handle semi-affine maps. + return memrefType; // Project out the old data dimensions. fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds()); SmallVector newShape(newRank); for (unsigned d = 0; d < newRank; ++d) { // The lower bound for the shape is always zero. auto ubConst = fac.getConstantUpperBound(d); - // For a static memref and an affine map with no symbols, this is always - // bounded. + // For a static memref and an affine map with no symbols, this is + // always bounded. assert(ubConst.hasValue() && "should always have an upper bound"); if (ubConst.getValue() < 0) // This is due to an invalid map that maps to a negative space. - return failure(); + return memrefType; newShape[d] = ubConst.getValue() + 1; } - auto oldMemRef = allocOp.getResult(); - SmallVector symbolOperands(allocOp.getSymbolicOperands()); - + // Create the new memref type after trivializing the old layout map. MemRefType newMemRefType = MemRefType::Builder(memrefType) .setShape(newShape) .setAffineMaps(b.getMultiDimIdentityMap(newRank)); - auto newAlloc = b.create(allocOp.getLoc(), newMemRefType); - // Replace all uses of the old memref. - if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, - /*extraIndices=*/{}, - /*indexRemap=*/layoutMap, - /*extraOperands=*/{}, - /*symbolOperands=*/symbolOperands))) { - // If it failed (due to escapes for example), bail out. - newAlloc.erase(); - return failure(); - } - // Replace any uses of the original alloc op and erase it. All remaining uses - // have to be dealloc's; RAMUW above would've failed otherwise. - assert(llvm::all_of(oldMemRef.getUsers(), - [](Operation *op) { return isa(op); })); - oldMemRef.replaceAllUsesWith(newAlloc); - allocOp.erase(); - return success(); + return newMemRefType; } diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/normalize-memrefs.mlir rename from mlir/test/Transforms/memref-normalize.mlir rename to mlir/test/Transforms/normalize-memrefs.mlir --- a/mlir/test/Transforms/memref-normalize.mlir +++ b/mlir/test/Transforms/normalize-memrefs.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s +// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s + +// This file tests whether memref types with non-trivial layout maps are +// normalized to trivial (identity) layout maps. // CHECK-LABEL: func @permute() func @permute() { @@ -143,3 +146,61 @@ } return } + +#tile = affine_map<(i) -> (i floordiv 4, i mod 4)> + +// The following test cases check inter-procedural memref normalization. + +// Test case 1: Check normalization for multiple memrefs in a function argument list.
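The expected indices in the CHECK lines below are plain evaluations of #tile: index 0 maps to (0 floordiv 4, 0 mod 4) = (0, 0) and index 10 maps to (2, 2), so %A[0] and %A[10] become [0, 0] and [2, 2] accesses on the normalized memref<4x4xf64>. Similarly, memref<8xf64, #tile> has maximal index 7, which maps to (1, 3), giving the normalized type memref<2x4xf64>.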
+// CHECK-LABEL: func @multiple_argument_type +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64 +func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 { + %a = affine.load %A[0] : memref<16xf64, #tile> + %p = mulf %a, %a : f64 + affine.store %p, %A[10] : memref<16xf64, #tile> + call @single_argument_type(%C): (memref<8xf64, #tile>) -> () + return %B : f64 +} + +// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64> +// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64 +// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64> +// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> () +// CHECK: return %[[B]] : f64 + +// Test case 2: Check normalization for single memref argument in a function. +// CHECK-LABEL: func @single_argument_type +// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) +func @single_argument_type(%C : memref<8xf64, #tile>) { + %a = alloc(): memref<8xf64, #tile> + %b = alloc(): memref<16xf64, #tile> + %d = constant 23.0 : f64 + %e = alloc(): memref<24xf64> + call @single_argument_type(%a): (memref<8xf64, #tile>) -> () + call @single_argument_type(%C): (memref<8xf64, #tile>) -> () + call @multiple_argument_type(%b, %d, %a, %e): (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64 + return +} + +// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64> +// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64> +// CHECK: %cst = constant 2.300000e+01 : f64 +// CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64> +// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> () +// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> () +// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64 + +// Test case 3: Check function returning any other type except memref. +// CHECK-LABEL: func @non_memref_ret +// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> i1 +func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 { + %d = constant 1 : i1 + return %d : i1 +} + +// Test case 4: No normalization should take place because the function is returning the memref. +// CHECK-LABEL: func @memref_used_in_return +// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}> +func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) { + return %A : memref<8xf64, #tile> +}
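Note on test case 4: the memref keeps its #tile layout because its use in the return op is a non-dereferencing user that is neither a DeallocOp nor a CallOp, so areMemRefsNormalizable rejects the function and its signature (and argument type) is left untouched; handling returned memrefs is the TODO noted in updateFunctionSignature.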