diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_DIALECT_AFFINE_LOOPUTILS_H
 #define MLIR_DIALECT_AFFINE_LOOPUTILS_H
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Block.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -22,7 +23,6 @@
 #include <optional>
 
 namespace mlir {
-class AffineMap;
 class LoopLikeOpInterface;
 class OpBuilder;
 class Value;
@@ -347,6 +347,56 @@
   return result;
 }
 
+/// Class that encapsulates an Affine tiling analysis to find ideal tile sizes
+/// for DMA copies with a fixed-width chunking constraint.
+/// This pass assumes:
+///   - The loop band is perfect.
+///   - The loop bounds are static.
+///   - The Affine array accesses are parallel.
+///   - The DMA copies are continuous.
+/// Refer to "Tiling for DMA-Based Hardware Accelerators" for more information.
+/// doi: https://doi.org/10.1145/3589610.3596283
+class FixedWidthAffineTilingAnalysis {
+private:
+  // Variables provided by the user.
+  ArrayRef<AffineForOp> band;
+  uint64_t chunkSizeBytes;
+  uint64_t cacheSizeBytes;
+  // Map from a memref to the (lower bound, upper bound) maps of its accesses
+  // within the loop band; the maps take the tile sizes as symbols.
+  llvm::SmallDenseMap<Value, std::pair<AffineMap, AffineMap>, 4>
+      boundConstraints;
+  // Variables used for the exhaustive search.
+  SmallVector<unsigned, 6> bestTileSizes;
+  uint64_t bestFootprint;
+  bool foundLegalTileSizes;
+
+  /// Populates the bounding constraints on the memrefs within a band based on
+  /// the Affine accesses on the memrefs.
+  LogicalResult getBoundingConstraints();
+  /// Get the footprint of all the copies that would be generated from the
+  /// given tile sizes using the bound constraints previously calculated.
+  /// Returns 0 when the tile sizes are illegal, i.e. some copy width is not a
+  /// multiple of the chunk size.
+  uint64_t getFootprintOfCopies(SmallVectorImpl<unsigned> *tileSizes);
+  /// A recursive search that tests different tile sizes and populates the
+  /// bestTileSizes array with the best tile sizes. The function will set
+  /// foundLegalTileSizes to true if a legal tile size is found.
+  /// NOTE: A recursive algorithm was used since the rank of the loop band is
+  /// unknown.
+  /// NOTE: Only divisors of each loop's trip count are tried, so the cost of
+  /// the search is the product, over the loops of the band, of the number of
+  /// divisors of each trip count.
+  void exhaustiveSearchForTileSizes(SmallVectorImpl<unsigned> *tileSizes,
+                                    unsigned id);
+
+public:
+  /// Calculates the best tile sizes that maximally fill the cache while
+  /// maintaining expected DMA copy widths to be multiples of a given fixed-
+  /// width chunk size.
+  /// Assumes the memory are copied in a single, continuous DMA copy.
+  LogicalResult getTileSizes(SmallVectorImpl<unsigned> *tileSizes);
+
+  explicit FixedWidthAffineTilingAnalysis(ArrayRef<AffineForOp> band,
+                                          uint64_t chunkSizeBytes,
+                                          uint64_t cacheSizeBytes);
+};
+
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -196,6 +196,10 @@
     ListOption<"tileSizes", "tile-sizes", "unsigned",
                "List of tile sizes for each perfect nest "
                "(overridden by -tile-size)">,
+    Option<"fpChunkDivisor", "fp-chunk-divisor", "uint64_t", /*default*/"0",
+           "Specify if the footprint must be a multiple of the given chunk "
+           "divisor (in bytes). This is useful for generating DMAs with "
+           "constrained copy widths">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -212,6 +212,17 @@
     return;
   }
 
+  // If a footprint chunk divisor is provided, run the fixed-width tiling
+  // analysis to choose the tile sizes.
+  if (fpChunkDivisor != 0) {
+    FixedWidthAffineTilingAnalysis fwata(band, fpChunkDivisor, cacheSizeBytes);
+    if (!failed(fwata.getTileSizes(tileSizes)))
+      return;
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failed to find legal tile sizes that do not overflow the "
+                  "cache. Falling back on original implementation.\n");
+  }
+
   // Divide all loops equally in an attempt to reduce footprint.
   // TODO: this is approximate. Ideally, obtain reuse factor /
   // profitability along each dimension and weight tile sizes based on that as
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2770,3 +2770,326 @@
 
   return success();
 }
+
+/// Linearize the memory locations of the memory access relations produced by
+/// the MemRefAccess class. This is so this analysis can handle multi-D array
+/// accesses.
+static LogicalResult linearizeAffineAccessRelation(FlatAffineRelation &relation,
+                                                   MemRefAccess &access) {
+  // A 1-D access is already linear; nothing to do. (Checked before appending
+  // the new variable so no dangling, unconstrained var is left behind.)
+  unsigned rank = access.getRank();
+  if (rank == 1)
+    return LogicalResult::success();
+  // Add a new dim to represent the linearized memory location.
+  relation.appendDomainVar();
+  // Set the new dim's constraint as a function of the non-linearized memory
+  // locations, assuming a row-major identity layout.
+  MemRefType memrefTy = access.memref.getType().cast<MemRefType>();
+  assert(memrefTy.hasStaticShape() &&
+         "Cannot delinearize dynamic shape memref while still being affine.");
+  assert(memrefTy.getLayout().isIdentity() &&
+         "TODO: Use the layout to delinearize instead of assuming row-major.");
+  SmallVector<int64_t, 8> linearizedConstraint(relation.getNumCols(), 0);
+  auto shape = memrefTy.getShape();
+  // The range (memory index) variables sit just before the constant column.
+  unsigned memLocBase = relation.getNumCols() - rank - 1;
+  for (unsigned i = 0; i < rank; i++) {
+    // Row-major stride of dimension i: product of the trailing dim sizes.
+    int64_t mul = 1;
+    for (unsigned j = i + 1; j < rank; j++)
+      mul *= shape[j];
+    linearizedConstraint[memLocBase + i] = mul;
+  }
+  // linearized = sum_i(stride_i * index_i); the appended var is the last
+  // domain variable, located directly before the range variables.
+  linearizedConstraint[memLocBase - 1] = -1;
+  relation.addEquality(linearizedConstraint);
+  // Project out the non-linearized memory locations.
+  relation.projectOut(memLocBase, rank);
+  return LogicalResult::success();
+}
+
+/// Attempts to find a single constant upper or lower bound on the variable at
+/// `loopIVPos` within a flat affine relation. Returns the row index of the
+/// matching inequality; asserts that exactly one such row exists.
+static int64_t findConstBoundIneq(FlatAffineRelation &relation,
+                                  unsigned loopIVPos, int64_t constBound,
+                                  bool isUB) {
+  int64_t ineqRow = -1;
+  unsigned constPos = relation.getNumCols() - 1;
+  for (unsigned r = 0; r < relation.getNumInequalities(); r++) {
+    // An upper bound has a negative IV coefficient (c - iv >= 0); a lower
+    // bound has a positive one (iv - c >= 0).
+    int64_t sign = isUB ? 1 : -1;
+    if (relation.atIneq(r, loopIVPos) * sign >= 0)
+      continue;
+    if (relation.atIneq(r, constPos) * sign != constBound)
+      continue;
+    // The bound must involve only the IV and the constant.
+    bool allElseZero = true;
+    for (unsigned c = 0; c < relation.getNumCols() - 1; c++) {
+      if (c == loopIVPos)
+        continue;
+      if (relation.atIneq(r, c) != 0) {
+        allElseZero = false;
+        break;
+      }
+    }
+    if (!allElseZero)
+      continue;
+    assert(ineqRow == -1 &&
+           "TODO: handle multiple bound constraints for a single loop.");
+    ineqRow = r;
+  }
+  assert(ineqRow != -1 && "findConstBoundIneq: could not find constant bound.");
+  return ineqRow;
+}
+
+/// Sets the lower bound constraint from a given AffineForOp to 0.
+static LogicalResult setLowerBoundOfLoopToZero(AffineForOp forOp,
+                                               FlatAffineRelation &relation) {
+  assert(forOp.hasConstantLowerBound() &&
+         "setLowerBoundofLoopToZero: forOp expected to have constant lower "
+         "bound.");
+  // If the lower bound is already zero, no need to do anything.
+  int64_t constLB = forOp.getConstantLowerBound();
+  if (constLB == 0)
+    return LogicalResult::success();
+  // Find the position in the relation of the loop's induction variable.
+  unsigned loopIVPos;
+  if (!relation.findVar(forOp.getInductionVar(), &loopIVPos))
+    return forOp->emitError(
+        "setLowerBoundOfLoopToZero: cannot find loop in relation\n");
+  // Find the constant lower bound constraint of the loop iv.
+  int64_t lbIneqRow =
+      findConstBoundIneq(relation, loopIVPos, constLB, /*isUB*/ false);
+  // Replace the lower bound of the induction variable with 0.
+  // IV >= 0
+  SmallVector<int64_t, 8> lbConstraint(relation.getNumCols(), 0);
+  lbConstraint[loopIVPos] = 1;
+  relation.removeInequality(lbIneqRow);
+  relation.addInequality(lbConstraint);
+  return LogicalResult::success();
+}
+
+/// Set the upper bound constraint from a given AffineForOp to a symbolic
+/// variable to act as the tile size.
+static LogicalResult setUpperBoundOfLoopToSymbol(AffineForOp forOp,
+                                                 FlatAffineRelation &relation) {
+  assert(forOp.hasConstantUpperBound() &&
+         "setUpperBoundOfLoopToSymbol: forOp expected to have constant UB.");
+  // Create a symbolic variable to represent the tile size.
+  unsigned symPos = relation.appendSymbolVar();
+  // Find the position in the relation of the loop's induction variable.
+  unsigned loopIVPos;
+  if (!relation.findVar(forOp.getInductionVar(), &loopIVPos))
+    return forOp->emitError(
+        "setUpperBoundOfLoopToSymbol: cannot find loop in relation\n");
+  // Find the constant upper bound constraint of the loop iv.
+  int64_t constUB = forOp.getConstantUpperBound() - 1;
+  int64_t ubIneqRow =
+      findConstBoundIneq(relation, loopIVPos, constUB, /*isUB*/ true);
+  // Replace the upper bound of the induction variable with the tile size.
+  // iv < s ==> iv <= s - 1 ==> s - iv - 1 >= 0
+  SmallVector<int64_t, 8> ubConstraint(relation.getNumCols(), 0);
+  unsigned constPos = relation.getNumCols() - 1;
+  ubConstraint[symPos] = 1;
+  ubConstraint[loopIVPos] = -1;
+  ubConstraint[constPos] = -1;
+  relation.removeInequality(ubIneqRow);
+  relation.addInequality(ubConstraint);
+  // Set the upper bound of the tile size to the trip count.
+  // s <= TC ==> TC - s >= 0
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  assert(tripCount.has_value() && "Trip count expected to be constant.");
+  SmallVector<int64_t, 8> tileUBConstraint(relation.getNumCols(), 0);
+  tileUBConstraint[constPos] = tripCount.value();
+  tileUBConstraint[symPos] = -1;
+  relation.addInequality(tileUBConstraint);
+  return LogicalResult::success();
+}
+
+LogicalResult
+mlir::affine::FixedWidthAffineTilingAnalysis::getBoundingConstraints() {
+  assert(boundConstraints.empty());
+  // Collect all of the relations of the memref that is being accessed.
+  llvm::SmallDenseMap<Value, FlatAffineRelation> relations;
+  auto result = band.front()->walk([&](Operation *opInst) -> WalkResult {
+    // If not a load or store op, skip.
+    if (!isa<AffineLoadOp, AffineStoreOp>(opInst))
+      return WalkResult::advance();
+    // Obtain the access relation from the MemRef access.
+    MemRefAccess access(opInst);
+    FlatAffineRelation accessRel;
+    if (failed(access.getAccessRelation(accessRel)))
+      return opInst->emitError("Error obtaining access relation\n");
+    LLVM_DEBUG(llvm::dbgs() << "Access relation:\n");
+    LLVM_DEBUG(accessRel.print(llvm::dbgs()));
+    // Linearize the memory accesses on the dimensions. This analysis expects
+    // the input to be 1D.
+    if (failed(linearizeAffineAccessRelation(accessRel, access)))
+      return opInst->emitError("Failed to linearize access relation\n");
+    LLVM_DEBUG(llvm::dbgs() << "Linearized access relation:\n");
+    LLVM_DEBUG(accessRel.print(llvm::dbgs()));
+    // Simplify the polyhedron by setting the lower bound of the loop
+    // induction variables to 0, and the upper bound to a symbol.
+    for (AffineForOp loop : band) {
+      if (failed(setLowerBoundOfLoopToZero(loop, accessRel)))
+        return LogicalResult::failure();
+      if (failed(setUpperBoundOfLoopToSymbol(loop, accessRel)))
+        return LogicalResult::failure();
+    }
+    LLVM_DEBUG(llvm::dbgs() << "Simplified access relation:\n");
+    LLVM_DEBUG(accessRel.print(llvm::dbgs()));
+    // Project out the dimensional variables corresponding to the loop
+    // induction variables.
+    accessRel.projectOut(0, accessRel.getNumDimVars() - 1);
+    LLVM_DEBUG(llvm::dbgs() << "Projected access relation:\n");
+    LLVM_DEBUG(accessRel.print(llvm::dbgs()));
+    // Insert into the map of access relations. At the moment these are
+    // just being appended to the accesses. Ideally this should use some
+    // form of unionBoundingBox, however have found it to be unusable for
+    // this case.
+    auto it = relations.find(access.memref);
+    if (it == relations.end())
+      relations[access.memref] = std::move(accessRel);
+    else
+      it->second.append(accessRel);
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted())
+    return LogicalResult::failure();
+  // TODO: The relation should be cleaned from clear duplicates to improve
+  // performance. Cannot use removeTrivialRedundancies. It removes too
+  // much. This would not be needed if unionBoundingBox would work.
+  LLVM_DEBUG({
+    for (auto &relation : relations) {
+      llvm::dbgs() << "Full Relation for memref " << relation.first << "\n";
+      relation.second.print(llvm::dbgs());
+    }
+  });
+  // Collect the lower and upper bounds of the access relations.
+  // These are what this analysis calls the "bound constraints".
+  for (auto &relation : relations) {
+    assert(relation.second.getNumSymbolVars() == band.size());
+    assert(relation.second.getNumDimVars() == 1);
+    boundConstraints[relation.first] = relation.second.getLowerAndUpperBound(
+        0, 0, 1, 1, {}, band.front()->getContext());
+  }
+  LLVM_DEBUG({
+    for (const auto &mapPair : boundConstraints) {
+      llvm::dbgs()
+          << "Found the following bounds on the following memref's accesses: "
+          << mapPair.first << "\n";
+      llvm::dbgs() << "\tLower bound map: " << mapPair.second.first << "\n";
+      llvm::dbgs() << "\tUpper bound map: " << mapPair.second.second << "\n";
+    }
+  });
+  return LogicalResult::success();
+}
+
+/// Get the width of the data copy that would be generated from the given tile
+/// sizes based on the precalculated bounds.
+static int64_t getCopyWidth(SmallVectorImpl<unsigned> *tileSizes,
+                            const std::pair<AffineMap, AffineMap> &bounds) {
+  // AffineMaps have a method to fold constant values. Those values must be
+  // Attributes. Convert the tile sizes to IntegerAttributes.
+  SmallVector<Attribute, 4> tileSizeAttrs(tileSizes->size());
+  Type intTy = IntegerType::get(bounds.first.getContext(), 64);
+  for (size_t i = 0; i < tileSizes->size(); i++)
+    tileSizeAttrs[i] = IntegerAttr::get(intTy, (*tileSizes)[i]);
+  // Fold the upper and lower bounds.
+  SmallVector<Attribute, 4> upperBounds;
+  SmallVector<Attribute, 4> lowerBounds;
+  LogicalResult res = bounds.second.constantFold(tileSizeAttrs, upperBounds);
+  assert(res.succeeded());
+  res = bounds.first.constantFold(tileSizeAttrs, lowerBounds);
+  assert(res.succeeded());
+  (void)res;
+  assert(!upperBounds.empty() && !lowerBounds.empty());
+  // Find the max upper bound and the min lower bound. Note the min-reduction
+  // must fold over minLB (a previous revision accidentally folded maxUB).
+  int64_t maxUB = upperBounds[0].cast<IntegerAttr>().getInt();
+  int64_t minLB = lowerBounds[0].cast<IntegerAttr>().getInt();
+  for (size_t i = 1; i < upperBounds.size(); i++)
+    maxUB = std::max(maxUB, upperBounds[i].cast<IntegerAttr>().getInt());
+  for (size_t i = 1; i < lowerBounds.size(); i++)
+    minLB = std::min(minLB, lowerBounds[i].cast<IntegerAttr>().getInt());
+  // Return the distance between the max UB and min LB.
+  return maxUB - minLB;
+}
+
+uint64_t mlir::affine::FixedWidthAffineTilingAnalysis::getFootprintOfCopies(
+    SmallVectorImpl<unsigned> *tileSizes) {
+  assert(!boundConstraints.empty());
+  // Accumulate the footprint from each unique MemRef in the loop.
+  uint64_t footprint = 0;
+  for (const auto &mapPair : boundConstraints) {
+    int64_t copyWidth = getCopyWidth(tileSizes, mapPair.second);
+    assert(copyWidth > 0);
+    // Check that the copy width is a multiple of the chunk size.
+    auto memrefTy = mapPair.first.getType().cast<MemRefType>();
+    uint64_t copyWidthBytes =
+        copyWidth * (memrefTy.getElementTypeBitWidth() / 8);
+    if (copyWidthBytes % chunkSizeBytes != 0)
+      return 0;
+    // Add its width to the total footprint.
+    footprint += copyWidthBytes;
+  }
+  return footprint;
+}
+
+void mlir::affine::FixedWidthAffineTilingAnalysis::exhaustiveSearchForTileSizes(
+    SmallVectorImpl<unsigned> *tileSizes, unsigned id) {
+  // Base condition: a complete candidate tuple of tile sizes was chosen.
+  if (id == band.size()) {
+    uint64_t footprint = getFootprintOfCopies(tileSizes);
+    // If footprint is 0, this signifies that the tile size was not legal.
+    if (footprint != 0) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Found legal tile size: ";
+        for (size_t i = 0; i < tileSizes->size(); i++)
+          llvm::dbgs() << (*tileSizes)[i] << " ";
+        llvm::dbgs() << "\n";
+        llvm::dbgs() << "\tFootprint: " << footprint << "\n";
+      });
+      // If the footprint fits in the cache and is better than the last
+      // best footprint, this is the best tile sizes found so far.
+      if (footprint <= cacheSizeBytes && footprint > bestFootprint) {
+        LLVM_DEBUG(llvm::dbgs() << "\t*New best footprint*\n");
+        bestFootprint = footprint;
+        std::copy(tileSizes->begin(), tileSizes->end(), bestTileSizes.begin());
+        foundLegalTileSizes = true;
+      }
+    }
+    return;
+  }
+
+  // Iterate over the possible tile sizes. To avoid max and min bounds, only
+  // tile sizes that cleanly divide the trip count of the loop are considered.
+  std::optional<uint64_t> tripCount = getConstantTripCount(band[id]);
+  assert(tripCount.has_value() && "Trip count expected to be constant.");
+  for (uint64_t ts = 1; ts <= tripCount.value(); ts++) {
+    if (tripCount.value() % ts != 0)
+      continue;
+    (*tileSizes)[id] = ts;
+    exhaustiveSearchForTileSizes(tileSizes, id + 1);
+  }
+}
+
+LogicalResult mlir::affine::FixedWidthAffineTilingAnalysis::getTileSizes(
+    SmallVectorImpl<unsigned> *tileSizes) {
+  // TODO: Before doing anything, check to ensure that all the assumptions
+  // from the paper are met. Failing if they are not. meetsAssumptions()
+  // Populate the bounding constraints.
+  if (failed(getBoundingConstraints()))
+    return LogicalResult::failure();
+  // Exhaustively search for the ideal tile sizes based on the constraints.
+  SmallVector<unsigned, 6> tmpTileSizes(band.size(), 1);
+  exhaustiveSearchForTileSizes(&tmpTileSizes, 0);
+  if (!foundLegalTileSizes)
+    return LogicalResult::failure();
+  // Copy the best tile sizes found into the given tileSizes.
+  tileSizes->resize(bestTileSizes.size());
+  std::copy(bestTileSizes.begin(), bestTileSizes.end(), tileSizes->begin());
+  return LogicalResult::success();
+}
+
+mlir::affine::FixedWidthAffineTilingAnalysis::FixedWidthAffineTilingAnalysis(
+    ArrayRef<AffineForOp> band, uint64_t chunkSizeBytes,
+    uint64_t cacheSizeBytes)
+    : band(band), chunkSizeBytes(chunkSizeBytes),
+      cacheSizeBytes(cacheSizeBytes), bestFootprint(0),
+      foundLegalTileSizes(false) {
+  bestTileSizes.resize(band.size());
+}
diff --git a/mlir/test/Dialect/Affine/loop-tiling-fixed-width.mlir b/mlir/test/Dialect/Affine/loop-tiling-fixed-width.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-tiling-fixed-width.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -affine-loop-tile="cache-size=256 fp-chunk-divisor=32" | FileCheck %s
+
+// CHECK-LABEL: func @simple
+func.func @simple(%A : memref<256xf32>, %C : memref<232xf32>) {
+  // CHECK: affine.for %{{arg[0-9]+}} = 8 to 240 {
+  affine.for %i = 8 to 240 step 1 {
+    %0 = affine.load %A[%i - 8] : memref<256xf32>
+    %1 = affine.load %A[%i + 2] : memref<256xf32>
+    %2 = affine.load %A[%i + 16] : memref<256xf32>
+    %3 = arith.addf %0, %1 : f32
+    %4 = arith.addf %2, %3 : f32
+    affine.store %4, %C[%i] : memref<232xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @complex
+func.func @complex(%A : memref<7384xf32, 1>, %B : memref<404416xf32, 1>, %C : memref<9256xf32, 1>) {
+  // CHECK: step 13
+  affine.for %i = 0 to 13 step 1 {
+    // CHECK-NEXT: step 712
+    affine.for %j = 0 to 712 step 1 {
+      // CHECK-NEXT: step 8
+      affine.for %k = 0 to 568 step 1 {
+        %0 = affine.load %C[712 * %i + %j] : memref<9256xf32, 1>
+        %1 = affine.load %A[568 * %i + %k] : memref<7384xf32, 1>
+        %2 = affine.load %B[712 * %k + %j] : memref<404416xf32, 1>
+        %3 = arith.addf %0, %1 : f32
+        %4 = arith.addf %2, %3 : f32
+        affine.store %4, %C[712 * %i + %j] : memref<9256xf32, 1>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @complex2
+func.func @complex2(%arg0: memref<64x64xf32, 1>, %arg1: memref<32x5x5xf32, 1>, %arg2: memref<32xf32, 1>, %arg3: memref<32x60x60xf32, 1>) {
+  // CHECK: step 16
+  affine.for %arg6 = 0 to 32 {
+    // CHECK: step 60
+    affine.for %arg7 = 0 to 60 {
+      // CHECK: step 60
+      affine.for %arg8 = 0 to 60 {
+        // CHECK: step 5
+        affine.for %arg9 = 0 to 5 {
+          // CHECK: step 5
+          affine.for %arg10 = 0 to 5 {
+            %0 = affine.load %arg3[%arg6, %arg7, %arg8] : memref<32x60x60xf32, 1>
+            %1 = affine.load %arg0[%arg9 + %arg7, %arg10 + %arg8] : memref<64x64xf32, 1>
+            %2 = affine.load %arg1[%arg6, %arg9, %arg10] : memref<32x5x5xf32, 1>
+            %3 = arith.mulf %1, %2 : f32
+            %4 = arith.addf %0, %3 : f32
+            affine.store %4, %arg3[%arg6, %arg7, %arg8] : memref<32x60x60xf32, 1>
+          }
+        }
+      }
+    }
+  }
+  // CHECK: step 16
+  affine.for %arg6 = 0 to 32 {
+    // CHECK: step 60
+    affine.for %arg7 = 0 to 60 {
+      // CHECK: step 60
+      affine.for %arg8 = 0 to 60 {
+        %0 = affine.load %arg2[%arg6] : memref<32xf32, 1>
+        %1 = affine.load %arg3[%arg6, %arg7, %arg8] : memref<32x60x60xf32, 1>
+        %2 = arith.maxf %0, %1 : f32
+        affine.store %2, %arg2[%arg6] : memref<32xf32, 1>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @fits
+func.func @fits(%arg0: memref<32x64xf32, 1>, %arg1: memref<32xf32, 1>, %arg2: memref<64xf32, 1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: affine.for %{{arg[0-9]+}} = 0 to 64 {
+  affine.for %arg3 = 0 to 64 {
+    // CHECK: affine.for %{{arg[0-9]+}} = 0 to 32 {
+    affine.for %arg4 = 0 to 32 {
+      %0 = affine.load %arg2[%arg3] : memref<64xf32, 1>
+      %1 = affine.load %arg1[%arg4] : memref<32xf32, 1>
+      %2 = affine.load %arg0[%arg4, %arg3] : memref<32x64xf32, 1>
+      %3 = arith.addf %1, %2 : f32
+      %4 = arith.addf %0, %3 : f32
+      affine.store %4, %arg2[%arg3] : memref<64xf32, 1>
+    }
+  }
+  // CHECK: affine.for %{{arg[0-9]+}} = 0 to 64 {
+  affine.for %arg3 = 0 to 64 {
+    %0 = affine.load %arg2[%arg3] : memref<64xf32, 1>
+    %1 = arith.cmpf olt, %0, %cst : f32
+    %2 = arith.select %1, %cst, %0 : f32
+    affine.store %2, %arg2[%arg3] : memref<64xf32, 1>
+  }
+  return
+}
\ No newline at end of file