diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -24,7 +24,9 @@
 class FuncOp;
 class ModuleOp;
 class Pass;
-template <typename T> class OperationPass;
+
+template <typename T>
+class OperationPass;

 /// Creates an instance of the BufferPlacement pass.
 std::unique_ptr<Pass> createBufferPlacementPass();
@@ -89,6 +91,10 @@
 /// Creates a pass which deletes symbol operations that are unreachable. This
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
+
+/// Creates an interprocedural pass to normalize memrefs to have a trivial
+/// (identity) layout map.
+std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 } // end namespace mlir

 #endif // MLIR_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -309,6 +309,11 @@
   let constructor = "mlir::createMemRefDataFlowOptPass()";
 }

+def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
+  let summary = "Interprocedural normalization of memrefs";
+  let constructor = "mlir::createNormalizeMemRefsPass()";
+}
+
 def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
   let summary = "Collapse parallel loops to use fewer induction variables";
   let constructor = "mlir::createParallelLoopCollapsingPass()";
@@ -405,5 +410,4 @@
   }];
   let constructor = "mlir::createSymbolDCEPass()";
 }
-
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -45,10 +45,18 @@
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
+/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
+/// uses of a memref without any requirement for access index rewrites. The
+/// default value of this flag is false.
+///
+/// 'handleDeallocOp', if set, lets DeallocOp, a non-dereferencing user, also
+/// be a candidate for replacement. The default value of this flag is false.
+///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-dereferencing context,
-/// except for dealloc's on the memref which are left untouched. See comments at
-/// function definition for an example.
+/// whenever a memref is used as an operand in a non-dereferencing context and
+/// 'allowNonDereferencingOps' is false, except for dealloc's on the memref
+/// which are left untouched. See comments at function definition for an
+/// example.
 //
 //  Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
 //  The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
@@ -57,28 +65,38 @@
 //  extra operands, note that 'indexRemap' would just be applied to existing
 //  indices (%i, %j).
 //  TODO: allow extraIndices to be added at any position.
-LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                       ArrayRef<Value> extraIndices = {},
-                                       AffineMap indexRemap = AffineMap(),
-                                       ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {},
-                                       Operation *domInstFilter = nullptr,
-                                       Operation *postDomInstFilter = nullptr);
+LogicalResult replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
+    AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
+    ArrayRef<Value> symbolOperands = {}, Operation *domInstFilter = nullptr,
+    Operation *postDomInstFilter = nullptr,
+    bool allowNonDereferencingOps = false, bool handleDeallocOp = false);

 /// Performs the same replacement as the other version above but only for the
-/// dereferencing uses of `oldMemRef` in `op`.
+/// dereferencing uses of `oldMemRef` in `op`. If 'allowNonDereferencingOps' is
+/// set to true, the replacement is performed for the non-dereferencing uses as
+/// well.
 LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        Operation *op,
                                        ArrayRef<Value> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
                                        ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {});
+                                       ArrayRef<Value> symbolOperands = {},
+                                       bool allowNonDereferencingOps = false);

 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
 LogicalResult normalizeMemRef(AllocOp op);

+/// Uses the old memref type's layout map to compute a new memref type with a
+/// new shape and an identity layout map, i.e., the old layout map is
+/// normalized to an identity map. Returns the old memref type unchanged if no
+/// normalization was needed or if the old layout map couldn't be transformed
+/// to an identity map.
+MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder,
+                               unsigned numSymbolicOperands);
+
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
 /// of all other AffineApplyOps reachable from input parameter 'operands'. If
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -96,13 +96,4 @@
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       applyOpPatternsAndFold(op, patterns);
   });
-
-  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
-  SmallVector<AllocOp, 4> allocOps;
-  func.walk([&](AllocOp op) { allocOps.push_back(op); });
-  for (auto allocOp : allocOps) {
-    normalizeMemRef(allocOp);
-  }
 }
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
   MemRefDataFlowOpt.cpp
+  NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -0,0 +1,219 @@
+//===- NormalizeMemRefs.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an interprocedural pass to normalize memrefs to have
+// identity layout maps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#define DEBUG_TYPE "normalize-memrefs"
+
+using namespace mlir;
+
+namespace {
+
+/// All memrefs with non-trivial layout maps, including those passed across
+/// function boundaries, are converted to memrefs with a trivial (identity)
+/// layout map.
+
+// Input:
+// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+//   (memref<16xf64, #tile>) {
+//   affine.for %arg3 = 0 to 16 {
+//     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+//     %p = mulf %a, %a : f64
+//     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+//   }
+//   %c = alloc() : memref<16xf64, #tile>
+//   %d = affine.load %c[0] : memref<16xf64, #tile>
+//   return %A: memref<16xf64, #tile>
+// }
+
+// Output:
+// module {
+//   func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+//     -> memref<4x4xf64> {
+//     affine.for %arg3 = 0 to 16 {
+//       %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//       %3 = mulf %2, %2 : f64
+//       affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//     }
+//     %0 = alloc() : memref<16xf64, #map0>
+//     %1 = affine.load %0[0] : memref<16xf64, #map0>
+//     return %arg0 : memref<4x4xf64>
+//   }
+// }
+
+struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
+  void runOnOperation() override;
+  void runOnFunction(FuncOp funcOp);
+  bool canNormalizeMemRefs(FuncOp funcOp);
+  void updateFunctionSignature(FuncOp funcOp);
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
+  return std::make_unique<NormalizeMemRefs>();
+}
+
+void NormalizeMemRefs::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  // We traverse each function within the module in order to normalize the
+  // memref type arguments.
+  // TODO: Handle external functions.
+  moduleOp.walk([&](FuncOp funcOp) {
+    if (canNormalizeMemRefs(funcOp))
+      runOnFunction(funcOp);
+  });
+}
+
+// Return true if this operation dereferences one or more memrefs.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO
+static bool isMemRefDereferencingOp(Operation &op) {
+  return isa<AffineLoadOp, AffineStoreOp, AffineDmaStartOp, AffineDmaWaitOp>(
+      op);
+}
+
+// Check whether all the uses of oldMemRef are either dereferencing uses or the
+// op is of type: DeallocOp or CallOp. Only if these constraints are satisfied
+// will the value become a candidate for replacement.
+static bool canNormalizeMemRefValue(Value::user_range opUsers) {
+  if (llvm::any_of(opUsers, [](Operation *op) {
+        if (isMemRefDereferencingOp(*op))
+          return false;
+        return !isa<DeallocOp, CallOp>(*op);
+      }))
+    return false;
+  return true;
+}
+
+// Check whether all the uses of AllocOps, CallOps, and function arguments of
+// a function are either dereferencing uses or of type DeallocOp or CallOp.
+// Only if these constraints are satisfied will the function become a candidate
+// for normalization.
+bool NormalizeMemRefs::canNormalizeMemRefs(FuncOp funcOp) {
+  SmallVector<AllocOp, 4> allocOps;
+  funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
+  for (AllocOp allocOp : allocOps) {
+    Value oldMemRef = allocOp.getResult();
+    if (!canNormalizeMemRefValue(oldMemRef.getUsers()))
+      return false;
+  }
+
+  SmallVector<CallOp, 4> callOps;
+  funcOp.walk([&](CallOp op) { callOps.push_back(op); });
+  for (CallOp callOp : callOps) {
+    for (unsigned resIndex : llvm::seq<unsigned>(0, callOp.getNumResults())) {
+      Value oldMemRef = callOp.getResult(resIndex);
+      if (oldMemRef.getType().getKind() == StandardTypes::MemRef)
+        if (!canNormalizeMemRefValue(oldMemRef.getUsers()))
+          return false;
+    }
+  }
+
+  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
+    if (oldMemRef.getType().getKind() == StandardTypes::MemRef)
+      if (!canNormalizeMemRefValue(oldMemRef.getUsers()))
+        return false;
+  }
+
+  return true;
+}
+
+// Fetch the updated argument and result types of the function and update the
+// function signature.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+  FunctionType functionType = funcOp.getType();
+  SmallVector<Type, 8> argTypes;
+  SmallVector<Type, 4> resultTypes;
+
+  for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
+    argTypes.push_back(arg.value().getType());
+  }
+
+  resultTypes = llvm::to_vector<4>(functionType.getResults());
+  // We create a new function type and modify the function signature with this
+  // new type.
+  FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
+                                               /*results=*/resultTypes,
+                                               /*context=*/&getContext());
+
+  // TODO: Handle ReturnOps to update the function results at the caller site.
+  funcOp.setType(newFuncType);
+}
+
+void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+  // Turn memrefs' non-identity layout maps into ones with identity. Collect
+  // alloc ops first and then process since normalizeMemRef replaces/erases ops
+  // during memref rewriting.
+  SmallVector<AllocOp, 4> allocOps;
+  funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
+  for (AllocOp allocOp : allocOps)
+    normalizeMemRef(allocOp);
+
+  // We use this OpBuilder to create a new memref layout later.
+  OpBuilder b(funcOp);
+
+  // Walk over each argument of the function to perform memref normalization
+  // (if any).
+  for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
+    Type argType = arg.value().getType();
+    MemRefType memrefType = argType.dyn_cast<MemRefType>();
+    // Check whether the argument is of MemRef type. Any other argument type
+    // can simply be part of the final function signature.
+    if (!memrefType)
+      continue;
+
+    unsigned argIndex = arg.index();
+    // Fetch a new memref type after normalizing the old memref to have an
+    // identity map layout.
+    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                   /*numSymbolicOperands=*/0);
+    if (newMemRefType == memrefType) {
+      // Either memrefType already had an identity map or the map couldn't be
+      // transformed to an identity map.
+      continue;
+    }
+
+    // Insert a new temporary argument with the new memref type.
+    BlockArgument newMemRef =
+        funcOp.front().insertArgument(argIndex, newMemRefType);
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
+    AffineMap layoutMap = memrefType.getAffineMaps().front();
+    // Replace all uses of the old memref.
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
+                                        /*extraIndices=*/{},
+                                        /*indexRemap=*/layoutMap,
+                                        /*extraOperands=*/{},
+                                        /*symbolOperands=*/{},
+                                        /*domInstFilter=*/nullptr,
+                                        /*postDomInstFilter=*/nullptr,
+                                        /*allowNonDereferencingOps=*/true,
+                                        /*handleDeallocOp=*/true))) {
+      // If the replacement failed (e.g., due to escaping uses), bail out and
+      // remove the temporary argument inserted previously.
+      funcOp.front().eraseArgument(argIndex);
+      continue;
+    }
+
+    // All uses of the argument with the old memref type were replaced
+    // successfully, so we can remove the old argument now.
+    funcOp.front().eraseArgument(argIndex + 1);
+  }
+
+  updateFunctionSignature(funcOp);
+}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -48,7 +48,8 @@
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
-                                            ArrayRef<Value> symbolOperands) {
+                                            ArrayRef<Value> symbolOperands,
+                                            bool allowNonDereferencingOps) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -67,11 +68,6 @@
   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
          newMemRef.getType().cast<MemRefType>().getElementType());

-  if (!isMemRefDereferencingOp(*op))
-    // Failure: memref used in a non-dereferencing context (potentially
-    // escapes); no replacement in these cases.
-    return failure();
-
   SmallVector<unsigned, 2> usePositions;
   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
     if (opEntry.value() == oldMemRef)
@@ -91,6 +87,18 @@
   unsigned memRefOperandPos = usePositions.front();

   OpBuilder builder(op);
+  // If the op is not a dereferencing op, only the memref operand needs to be
+  // switched; this is allowed only when 'allowNonDereferencingOps' is set.
+  if (!isMemRefDereferencingOp(*op)) {
+    if (!allowNonDereferencingOps)
+      // Failure: memref used in a non-dereferencing context (potentially
+      // escapes); no replacement in these cases unless
+      // allowNonDereferencingOps is set.
+      return failure();
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }
+  // Perform index rewrites for the dereferencing op and then replace the op.
   NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
   AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
   unsigned oldMapNumInputs = oldMap.getNumInputs();
@@ -112,7 +120,7 @@
       affineApplyOps.push_back(afOp);
     }
   } else {
-    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
   }

   // Construct new indices as a remap of the old ones if a remapping has been
@@ -141,14 +149,14 @@
     }
   } else {
     // No remapping specified.
-    remapOutputs.append(remapOperands.begin(), remapOperands.end());
+    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }

   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);

   // Prepend 'extraIndices' in 'newMapOperands'.
-  for (auto extraIndex : extraIndices) {
+  for (Value extraIndex : extraIndices) {
     assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
            "single result op's expected to generate these indices");
     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
@@ -167,12 +175,12 @@
   newMap = simplifyAffineMap(newMap);
   canonicalizeMapAndOperands(&newMap, &newMapOperands);
   // Remove any affine.apply's that became dead as a result of composition.
-  for (auto value : affineApplyOps)
+  for (Value value : affineApplyOps)
     if (value.use_empty())
       value.getDefiningOp()->erase();

-  // Construct the new operation using this memref.
   OperationState state(op->getLoc(), op->getName());
+  // Construct the new operation using this memref.
   state.operands.reserve(op->getNumOperands() + extraIndices.size());
   // Insert the non-memref operands.
   state.operands.append(op->operand_begin(),
@@ -196,11 +204,10 @@
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.first == oldMapAttrPair.first) {
+    if (namedAttr.first == oldMapAttrPair.first)
       state.attributes.push_back({namedAttr.first, newMapAttr});
-    } else {
+    else
       state.attributes.push_back(namedAttr);
-    }
   }

   // Create the new operation.
@@ -211,13 +218,12 @@
   return success();
 }

-LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                             ArrayRef<Value> extraIndices,
-                                             AffineMap indexRemap,
-                                             ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands,
-                                             Operation *domInstFilter,
-                                             Operation *postDomInstFilter) {
+LogicalResult mlir::replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
+    AffineMap indexRemap, ArrayRef<Value> extraOperands,
+    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
+    Operation *postDomInstFilter, bool allowNonDereferencingOps,
+    bool handleDeallocOp) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -261,16 +267,21 @@

     // Skip dealloc's - no replacement is necessary, and a memref replacement
     // at other uses doesn't hurt these dealloc's.
-    if (isa<DeallocOp>(op))
+    if (isa<DeallocOp>(op) && !handleDeallocOp)
       continue;

     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
-    if (!isMemRefDereferencingOp(*op))
-      // Failure: memref used in a non-dereferencing op (potentially escapes);
-      // no replacement in these cases.
-      return failure();
+    if (!isMemRefDereferencingOp(*op)) {
+      // Currently, DeallocOp and CallOp are the only non-dereferencing ops
+      // that are candidates for replacement.
+      // TODO: Add support for other kinds of ops.
+      if (!allowNonDereferencingOps)
+        return failure();
+      if (!isa<DeallocOp, CallOp>(*op))
+        return failure();
+    }

     // We'll first collect and then replace --- since replacement erases the op
     // that has the use, and that op could be postDomFilter or domFilter itself!
@@ -278,9 +289,9 @@
   }

   for (auto *op : opsToReplace) {
-    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
-                                        indexRemap, extraOperands,
-                                        symbolOperands)))
+    if (failed(replaceAllMemRefUsesWith(
+            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
+            symbolOperands, allowNonDereferencingOps)))
       llvm_unreachable("memref replacement guaranteed to succeed here");
   }

@@ -385,83 +396,101 @@
 // TODO: Currently works for static memrefs with a single layout map.
 LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   MemRefType memrefType = allocOp.getType();
-  unsigned rank = memrefType.getRank();
-  if (rank == 0)
-    return success();
-
-  auto layoutMaps = memrefType.getAffineMaps();
   OpBuilder b(allocOp);
-  if (layoutMaps.size() != 1)
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType =
+      normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands());
+  if (newMemRefType == memrefType)
+    // Either memrefType already had an identity map or the map couldn't be
+    // transformed to an identity map.
     return failure();

-  AffineMap layoutMap = layoutMaps.front();
+  Value oldMemRef = allocOp.getResult();

-  // Nothing to do for identity layout maps.
-  if (layoutMap == b.getMultiDimIdentityMap(rank))
-    return success();
+  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
+  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/layoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/symbolOperands,
+                                      /*domInstFilter=*/nullptr,
+                                      /*postDomInstFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If the replacement failed (e.g., due to escaping uses), bail out.
+    newAlloc.erase();
+    return failure();
+  }
+  // Replace any uses of the original alloc op and erase it. All remaining uses
+  // have to be dealloc's; RAMUW above would've failed otherwise.
+  assert(llvm::all_of(oldMemRef.getUsers(),
+                      [](Operation *op) { return isa<DeallocOp>(op); }));
+  oldMemRef.replaceAllUsesWith(newAlloc);
+  allocOp.erase();
+  return success();
+}
+
+MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
+                                     unsigned numSymbolicOperands) {
+  unsigned rank = memrefType.getRank();
+  if (rank == 0)
+    return memrefType;
+
+  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
+  if (layoutMaps.empty() ||
+      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+    // Either no map is associated with this memref or this memref has a
+    // trivial (identity) map.
+    return memrefType;
+  }

   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.

   // TODO: Only for static memref's for now.
   if (memrefType.getNumDynamicDims() > 0)
-    return failure();
+    return memrefType;

-  // We have a single map that is not an identity map. Create a new memref with
-  // the right shape and an identity layout map.
-  auto shape = memrefType.getShape();
-  FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
+  // We have a single map that is not an identity map. Create a new memref
+  // with the right shape and an identity layout map.
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  // FlatAffineConstraints may later use the symbolic operands.
+  FlatAffineConstraints fac(rank, numSymbolicOperands);
   for (unsigned d = 0; d < rank; ++d) {
     fac.addConstantLowerBound(d, 0);
     fac.addConstantUpperBound(d, shape[d] - 1);
   }
-
-  // We compose this map with the original index (logical) space to derive the
-  // upper bounds for the new index space.
+  // We compose this map with the original index (logical) space to derive
+  // the upper bounds for the new index space.
+  AffineMap layoutMap = layoutMaps.front();
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
-    // TODO: semi-affine maps.
-    return failure();
-
+    // TODO: Handle semi-affine maps.
+    return memrefType;
   // Project out the old data dimensions.
   fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
   SmallVector<int64_t, 4> newShape(newRank);
   for (unsigned d = 0; d < newRank; ++d) {
     // The lower bound for the shape is always zero.
     auto ubConst = fac.getConstantUpperBound(d);
-    // For a static memref and an affine map with no symbols, this is always
-    // bounded.
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded.
     assert(ubConst.hasValue() && "should always have an upper bound");
     if (ubConst.getValue() < 0)
       // This is due to an invalid map that maps to a negative space.
-      return failure();
+      return memrefType;
     newShape[d] = ubConst.getValue() + 1;
   }

-  auto oldMemRef = allocOp.getResult();
-  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
-
+  // Create the new memref type after trivializing the old layout map.
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setAffineMaps(b.getMultiDimIdentityMap(newRank));

-  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
-  // Replace all uses of the old memref.
-  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
-                                      /*extraIndices=*/{},
-                                      /*indexRemap=*/layoutMap,
-                                      /*extraOperands=*/{},
-                                      /*symbolOperands=*/symbolOperands))) {
-    // If it failed (due to escapes for example), bail out.
-    newAlloc.erase();
-    return failure();
-  }
-  // Replace any uses of the original alloc op and erase it. All remaining uses
-  // have to be dealloc's; RAMUW above would've failed otherwise.
-  assert(llvm::all_of(oldMemRef.getUsers(),
-                      [](Operation *op) { return isa<DeallocOp>(op); }));
-  oldMemRef.replaceAllUsesWith(newAlloc);
-  allocOp.erase();
-  return success();
+  return newMemRefType;
 }
diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/memref-normalize.mlir
deleted file mode 100644
--- a/mlir/test/Transforms/memref-normalize.mlir
+++ /dev/null
@@ -1,145 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s
-
-// CHECK-LABEL: func @permute()
-func @permute() {
-  %A = alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
-  affine.for %i = 0 to 64 {
-    affine.for %j = 0 to 256 {
-      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
-      "prevent.dce"(%1) : (f32) -> ()
-    }
-  }
-  dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
-  return
-}
-// The old memref alloc should disappear.
-// CHECK-NOT: memref<64x256xf32> -// CHECK: [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32> -// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 { -// CHECK-NEXT: affine.for %[[J:arg[0-9]+]] = 0 to 256 { -// CHECK-NEXT: affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32> -// CHECK-NEXT: "prevent.dce" -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: dealloc [[MEM]] -// CHECK-NEXT: return - -// CHECK-LABEL: func @shift -func @shift(%idx : index) { - // CHECK-NEXT: alloc() : memref<65xf32> - %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> - // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32> - affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> - affine.for %i = 0 to 64 { - %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> - "prevent.dce"(%1) : (f32) -> () - // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32> - } - return -} - -// CHECK-LABEL: func @high_dim_permute() -func @high_dim_permute() { - // CHECK-NOT: memref<64x128x256xf32, - %A = alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>> - // CHECK: %[[I:arg[0-9]+]] - affine.for %i = 0 to 64 { - // CHECK: %[[J:arg[0-9]+]] - affine.for %j = 0 to 128 { - // CHECK: %[[K:arg[0-9]+]] - affine.for %k = 0 to 256 { - %1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>> - // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32> - "prevent.dce"(%1) : (f32) -> () - } - } - } - return -} - -// CHECK-LABEL: func @invalid_map -func @invalid_map() { - %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>> - // CHECK: %{{.*}} = alloc() : memref<64x128xf32, - return -} - -// A tiled layout. -// CHECK-LABEL: func @data_tiling -func @data_tiling(%idx : index) { - // CHECK: alloc() : memref<8x32x8x16xf32> - %A = alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>> - // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16] - %1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>> - "prevent.dce"(%1) : (f32) -> () - return -} - -// Strides 2 and 4 along respective dimensions. -// CHECK-LABEL: func @strided -func @strided() { - %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>> - // CHECK: affine.for %[[IV0:.*]] = - affine.for %i = 0 to 64 { - // CHECK: affine.for %[[IV1:.*]] = - affine.for %j = 0 to 128 { - // CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32> - %1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>> - "prevent.dce"(%1) : (f32) -> () - } - } - return -} - -// Strided, but the strides are in the linearized space. -// CHECK-LABEL: func @strided_cumulative -func @strided_cumulative() { - %A = alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>> - // CHECK: affine.for %[[IV0:.*]] = - affine.for %i = 0 to 2 { - // CHECK: affine.for %[[IV1:.*]] = - affine.for %j = 0 to 5 { - // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32> - %1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>> - "prevent.dce"(%1) : (f32) -> () - } - } - return -} - -// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith -// when the index remap has symbols. 
-// CHECK-LABEL: func @symbolic_operands
-func @symbolic_operands(%s : index) {
-  // CHECK: alloc() : memref<100xf32>
-  %A = alloc()[%s] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
-  affine.for %i = 0 to 10 {
-    affine.for %j = 0 to 10 {
-      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
-      %1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
-      "prevent.dce"(%1) : (f32) -> ()
-    }
-  }
-  return
-}
-
-// Memref escapes; no normalization.
-// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
-func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
-  // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
-  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-  return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-}
-
-// Semi-affine maps, normalization not implemented yet.
-// CHECK-LABEL: func @semi_affine_layout_map
-func @semi_affine_layout_map(%s0: index, %s1: index) {
-  %A = alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
-  affine.for %i = 0 to 256 {
-    affine.for %j = 0 to 1024 {
-      // CHECK: memref<256x1024xf32, #map{{[0-9]+}}>
-      affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
-    }
-  }
-  return
-}
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -0,0 +1,243 @@
+// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s
+
+// This file tests whether memref types with non-trivial layout maps are
+// normalized to trivial (identity) layouts.
+
+// CHECK-LABEL: func @permute()
+func @permute() {
+  %A = alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
+  affine.for %i = 0 to 64 {
+    affine.for %j = 0 to 256 {
+      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
+      "prevent.dce"(%1) : (f32) -> ()
+    }
+  }
+  dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
+  return
+}
+// The old memref alloc should disappear.
+// CHECK-NOT: memref<64x256xf32>
+// CHECK: [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32>
+// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
+// CHECK-NEXT: affine.for %[[J:arg[0-9]+]] = 0 to 256 {
+// CHECK-NEXT: affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
+// CHECK-NEXT: "prevent.dce"
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: dealloc [[MEM]]
+// CHECK-NEXT: return
+
+// CHECK-LABEL: func @shift
+func @shift(%idx : index) {
+  // CHECK-NEXT: alloc() : memref<65xf32>
+  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  affine.for %i = 0 to 64 {
+    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+    "prevent.dce"(%1) : (f32) -> ()
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @high_dim_permute()
+func @high_dim_permute() {
+  // CHECK-NOT: memref<64x128x256xf32,
+  %A = alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
+  // CHECK: %[[I:arg[0-9]+]]
+  affine.for %i = 0 to 64 {
+    // CHECK: %[[J:arg[0-9]+]]
+    affine.for %j = 0 to 128 {
+      // CHECK: %[[K:arg[0-9]+]]
+      affine.for %k = 0 to 256 {
+        %1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
+        // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
+        "prevent.dce"(%1) : (f32) -> ()
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @invalid_map
+func @invalid_map() {
+  %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>>
+  // CHECK: %{{.*}} = alloc() : memref<64x128xf32,
+  return
+}
+
+// A tiled layout.
+// CHECK-LABEL: func @data_tiling
+func @data_tiling(%idx : index) {
+  // CHECK: alloc() : memref<8x32x8x16xf32>
+  %A = alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
+  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
+  %1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
+  "prevent.dce"(%1) : (f32) -> ()
+  return
+}
+
+// Strides 2 and 4 along respective dimensions.
+// CHECK-LABEL: func @strided
+func @strided() {
+  %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
+  // CHECK: affine.for %[[IV0:.*]] =
+  affine.for %i = 0 to 64 {
+    // CHECK: affine.for %[[IV1:.*]] =
+    affine.for %j = 0 to 128 {
+      // CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
+      %1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
+      "prevent.dce"(%1) : (f32) -> ()
+    }
+  }
+  return
+}
+
+// Strided, but the strides are in the linearized space.
+// CHECK-LABEL: func @strided_cumulative
+func @strided_cumulative() {
+  %A = alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
+  // CHECK: affine.for %[[IV0:.*]] =
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for %[[IV1:.*]] =
+    affine.for %j = 0 to 5 {
+      // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
+      %1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
+      "prevent.dce"(%1) : (f32) -> ()
+    }
+  }
+  return
+}
+
+// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith
+// when the index remap has symbols.
+// CHECK-LABEL: func @symbolic_operands
+func @symbolic_operands(%s : index) {
+  // CHECK: alloc() : memref<100xf32>
+  %A = alloc()[%s] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
+  affine.for %i = 0 to 10 {
+    affine.for %j = 0 to 10 {
+      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
+      %1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
+      "prevent.dce"(%1) : (f32) -> ()
+    }
+  }
+  return
+}
+
+// Memref escapes; no normalization.
+// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
+func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
+  // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
+  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
+  return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
+}
+
+// Semi-affine maps, normalization not implemented yet.
+// CHECK-LABEL: func @semi_affine_layout_map
+func @semi_affine_layout_map(%s0: index, %s1: index) {
+  %A = alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
+  affine.for %i = 0 to 256 {
+    affine.for %j = 0 to 1024 {
+      // CHECK: memref<256x1024xf32, #map{{[0-9]+}}>
+      affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
+    }
+  }
+  return
+}
+
+#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+
+// The following test cases check interprocedural memref normalization.
+
+// Test case 1: Check normalization for multiple memrefs in a function
+// argument list.
+// CHECK-LABEL: func @multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64
+func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64,
+                             %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  affine.store %p, %A[10] : memref<16xf64, #tile>
+  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
+  return %B : f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: return %[[B]] : f64
+
+// Test case 2: Check normalization for a single memref argument in a function.
+// CHECK-LABEL: func @single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>)
+func @single_argument_type(%C : memref<8xf64, #tile>) {
+  %a = alloc() : memref<8xf64, #tile>
+  %b = alloc() : memref<16xf64, #tile>
+  %d = constant 23.0 : f64
+  %e = alloc() : memref<24xf64>
+  call @single_argument_type(%a) : (memref<8xf64, #tile>) -> ()
+  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
+  %ret = call @multiple_argument_type(%b, %d, %a, %e) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
+  return
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64>
+// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64
+
+// Test case 3: Check a function returning any type other than memref.
+// CHECK-LABEL: func @non_memref_ret
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> i1
+func @non_memref_ret(%A : memref<8xf64, #tile>) -> i1 {
+  %d = constant 1 : i1
+  return %d : i1
+}
+
+// Test case 4: No normalization should take place because the function has a
+// memref argument that is used in a return.
+// CHECK-LABEL: func @memref_used_in_return
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
+func @memref_used_in_return(%A : memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
+  return %A : memref<8xf64, #tile>
+}
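Not part of the patch: a minimal usage sketch for reviewers trying the pass locally. Assuming the module below is saved as tile.mlir (a hypothetical file name), running `mlir-opt -normalize-memrefs tile.mlir` normalizes the tiled memref across the call boundary, mirroring the example in NormalizeMemRefs.cpp and the test cases above:

#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
func @callee(%arg0: memref<16xf64, #tile>) {
  // Dereferencing use: the access indices get remapped by the layout map.
  %v = affine.load %arg0[0] : memref<16xf64, #tile>
  return
}
func @caller(%m: memref<16xf64, #tile>) {
  // Non-dereferencing use (CallOp), which the pass is allowed to rewrite.
  call @callee(%m) : (memref<16xf64, #tile>) -> ()
  return
}

After the pass, both functions take memref<4x4xf64>, and the load is rewritten to affine.load %arg0[0 floordiv 4, 0 mod 4], which folds to %arg0[0, 0], matching the CHECK lines in normalize-memrefs.mlir above.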