diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h @@ -167,8 +167,9 @@ DependenceResult checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, - unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints, - SmallVector *dependenceComponents, + unsigned loopDepth, + FlatAffineValueConstraints *dependenceConstraints = nullptr, + SmallVector *dependenceComponents = nullptr, bool allowRAR = false); /// Utility function that returns true if the provided DependenceResult diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h @@ -112,19 +112,6 @@ return cst->getKind() == Kind::FlatAffineValueConstraints; } - /// Clears any existing data and reserves memory for the specified - /// constraints. - void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals = 0); - void reset(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0); - void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, - unsigned numReservedCols, unsigned numDims, unsigned numSymbols, - unsigned numLocals, ArrayRef valArgs); - void reset(unsigned numDims, unsigned numSymbols, unsigned numLocals, - ArrayRef valArgs); - /// Clones this object. std::unique_ptr clone() const; diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -172,10 +172,8 @@ MemRefAccess srcAccess(srcOp); for (auto *dstOp : loadAndStoreOps) { MemRefAccess dstAccess(dstOp); - FlatAffineValueConstraints dependenceConstraints; - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, depth, &dependenceConstraints, - /*dependenceComponents=*/nullptr); + DependenceResult result = + checkMemrefAccessDependence(srcAccess, dstAccess, depth); if (result.value != DependenceResult::NoDependence) return false; } @@ -230,6 +228,12 @@ } } +static SmallVector> +wrapInOptional(ArrayRef values) { + return llvm::to_vector(llvm::map_range( + values, [](Value v) -> std::optional { return v; })); +} + // Builds a system of constraints with dimensional variables corresponding to // the loop IVs of the forOps appearing in that order. Any symbols founds in // the bound operands are added as symbols in the system. Returns failure for @@ -260,7 +264,9 @@ } extractInductionVars(loopOps, indices); // Reset while associating Values in 'indices' to the domain. - domain->reset(numDims, /*numSymbols=*/0, /*numLocals=*/0, indices); + *domain = + FlatAffineValueConstraints(numDims, /*numSymbols=*/0, + /*numLocals=*/0, wrapInOptional(indices)); for (Operation *op : ops) { // Add constraints from forOp's bounds. if (AffineForOp forOp = dyn_cast(op)) { @@ -659,23 +665,27 @@ // memory locations. dstRel.inverse(); dstRel.compose(srcRel); - *dependenceConstraints = dstRel; // Add 'src' happens before 'dst' ordering constraints. - addOrderingConstraints(srcDomain, dstDomain, loopDepth, - dependenceConstraints); + addOrderingConstraints(srcDomain, dstDomain, loopDepth, &dstRel); // Return 'NoDependence' if the solution space is empty: no dependence. - if (dependenceConstraints->isEmpty()) + if (dstRel.isEmpty()) { + if (dependenceConstraints) + *dependenceConstraints = dstRel; return DependenceResult::NoDependence; + } // Compute dependence direction vector and return true. if (dependenceComponents != nullptr) - computeDirectionVector(srcDomain, dstDomain, loopDepth, - dependenceConstraints, dependenceComponents); + computeDirectionVector(srcDomain, dstDomain, loopDepth, &dstRel, + dependenceComponents); LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); - LLVM_DEBUG(dependenceConstraints->dump()); + LLVM_DEBUG(dstRel.dump()); + + if (dependenceConstraints) + *dependenceConstraints = dstRel; return DependenceResult::HasDependence; } @@ -700,12 +710,12 @@ auto *dstOp = loadAndStoreOps[j]; MemRefAccess dstAccess(dstOp); - FlatAffineValueConstraints dependenceConstraints; SmallVector depComps; // TODO: Explore whether it would be profitable to pre-compute and store // deps instead of repeatedly checking. DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, &depComps); + srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr, + &depComps); if (hasDependence(result)) depCompsVec->push_back(depComps); } diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -75,7 +75,8 @@ std::vector> *flattenedExprs, FlatAffineValueConstraints *localVarCst) { if (exprs.empty()) { - localVarCst->reset(numDims, numSymbols); + if (localVarCst) + *localVarCst = FlatAffineValueConstraints(numDims, numSymbols); return success(); } @@ -120,7 +121,9 @@ AffineMap map, std::vector> *flattenedExprs, FlatAffineValueConstraints *localVarCst) { if (map.getNumResults() == 0) { - localVarCst->reset(map.getNumDims(), map.getNumSymbols()); + if (localVarCst) + *localVarCst = + FlatAffineValueConstraints(map.getNumDims(), map.getNumSymbols()); return success(); } return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), @@ -132,7 +135,9 @@ IntegerSet set, std::vector> *flattenedExprs, FlatAffineValueConstraints *localVarCst) { if (set.getNumConstraints() == 0) { - localVarCst->reset(set.getNumDims(), set.getNumSymbols()); + if (localVarCst) + *localVarCst = + FlatAffineValueConstraints(set.getNumDims(), set.getNumSymbols()); return success(); } return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), @@ -231,50 +236,6 @@ return res; } -void FlatAffineValueConstraints::reset(unsigned numReservedInequalities, - unsigned numReservedEqualities, - unsigned newNumReservedCols, - unsigned newNumDims, - unsigned newNumSymbols, - unsigned newNumLocals) { - assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && - "minimum 1 column"); - *this = FlatAffineValueConstraints(numReservedInequalities, - numReservedEqualities, newNumReservedCols, - newNumDims, newNumSymbols, newNumLocals); -} - -void FlatAffineValueConstraints::reset(unsigned newNumDims, - unsigned newNumSymbols, - unsigned newNumLocals) { - reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, - /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1, - newNumDims, newNumSymbols, newNumLocals); -} - -void FlatAffineValueConstraints::reset( - unsigned numReservedInequalities, unsigned numReservedEqualities, - unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols, - unsigned newNumLocals, ArrayRef valArgs) { - assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && - "minimum 1 column"); - SmallVector, 8> newVals; - if (!valArgs.empty()) - newVals.assign(valArgs.begin(), valArgs.end()); - - *this = FlatAffineValueConstraints( - numReservedInequalities, numReservedEqualities, newNumReservedCols, - newNumDims, newNumSymbols, newNumLocals, newVals); -} - -void FlatAffineValueConstraints::reset(unsigned newNumDims, - unsigned newNumSymbols, - unsigned newNumLocals, - ArrayRef valArgs) { - reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, - newNumSymbols, newNumLocals, valArgs); -} - unsigned FlatAffineValueConstraints::appendDimVar(ValueRange vals) { unsigned pos = getNumDimVars(); return insertVar(VarKind::SetDim, pos, vals); diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -60,12 +60,19 @@ std::reverse(ops->begin(), ops->end()); } +static SmallVector> +wrapInOptional(ArrayRef values) { + return llvm::to_vector(llvm::map_range( + values, [](Value v) -> std::optional { return v; })); +} + // Populates 'cst' with FlatAffineValueConstraints which represent original // domain of the loop bounds that define 'ivs'. LogicalResult ComputationSliceState::getSourceAsConstraints(FlatAffineValueConstraints &cst) { assert(!ivs.empty() && "Cannot have a slice without its IVs"); - cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs); + cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0, + /*numLocals=*/0, wrapInOptional(ivs)); for (Value iv : ivs) { AffineForOp loop = getForInductionVarOwner(iv); assert(loop && "Expected affine for"); @@ -87,7 +94,8 @@ SmallVector values(ivs); // Append 'ivs' then 'operands' to 'values'. values.append(lbOperands[0].begin(), lbOperands[0].end()); - cst->reset(numDims, numSymbols, 0, values); + *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, + wrapInOptional(values)); // Add loop bound constraints for values which are loop IVs of the destination // of fusion and equality constraints for symbols which are constants. @@ -293,9 +301,10 @@ return isMaximalFastCheck; // Create constraints for the src loop nest being sliced. - FlatAffineValueConstraints srcConstraints; - srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, - /*numLocals=*/0, ivs); + FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(), + /*numSymbols=*/0, + /*numLocals=*/0, + wrapInOptional(ivs)); for (Value iv : ivs) { AffineForOp loop = getForInductionVarOwner(iv); assert(loop && "Expected affine for"); @@ -305,7 +314,7 @@ // Create constraints for the slice using the dst loop nest information. We // retrieve existing dst loops from the lbOperands. - SmallVector consumerIVs; + SmallVector consumerIVs; for (Value lbOp : lbOperands[0]) if (getForInductionVarOwner(lbOp)) consumerIVs.push_back(lbOp); @@ -315,9 +324,10 @@ for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i) consumerIVs.push_back(Value()); - FlatAffineValueConstraints sliceConstraints; - sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0, - /*numLocals=*/0, consumerIVs); + FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(), + /*numSymbols=*/0, + /*numLocals=*/0, + wrapInOptional(consumerIVs)); if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0]))) return std::nullopt; @@ -463,10 +473,11 @@ assert(loopDepth <= ivs.size() && "invalid 'loopDepth'"); // The first 'loopDepth' IVs are symbols for this region. ivs.resize(loopDepth); - SmallVector regionSymbols; + SmallVector regionSymbols; extractForInductionVars(ivs, ®ionSymbols); // A 0-d memref has a 0-d region. - cst.reset(rank, loopDepth, /*numLocals=*/0, regionSymbols); + cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, + wrapInOptional(regionSymbols)); return success(); } @@ -497,7 +508,8 @@ // We'll first associate the dims and symbols of the access map to the dims // and symbols resp. of cst. This will change below once cst is // fully constructed out. - cst.reset(numDims, numSymbols, 0, operands); + cst = FlatAffineValueConstraints(numDims, numSymbols, 0, + wrapInOptional(operands)); // Add equality constraints. // Add inequalities for loop lower/upper bounds. diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -225,11 +225,9 @@ unsigned numCommonLoops = getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - FlatAffineValueConstraints dependenceConstraints; // TODO: Cache dependence analysis results, check cache here. - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, - /*dependenceComponents=*/nullptr); + DependenceResult result = + checkMemrefAccessDependence(srcAccess, dstAccess, d); if (hasDependence(result)) { // Store minimum loop depth and break because we want the min 'd' at // which there is a dependence. diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -378,7 +378,6 @@ unsigned numOps = loadAndStoreOps.size(); unsigned numLoops = origLoops.size(); - FlatAffineValueConstraints dependenceConstraints; for (unsigned d = 1; d <= numLoops + 1; ++d) { for (unsigned i = 0; i < numOps; ++i) { Operation *srcOp = loadAndStoreOps[i]; @@ -388,9 +387,9 @@ MemRefAccess dstAccess(dstOp); SmallVector depComps; - dependenceConstraints.reset(); DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, &depComps); + srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr, + &depComps); // Skip if there is no dependence in this case. if (!hasDependence(result)) @@ -2362,7 +2361,7 @@ ivs.resize(numParamLoopIVs); SmallVector symbols; extractForInductionVars(ivs, &symbols); - regionCst->reset(rank, numParamLoopIVs, 0); + *regionCst = FlatAffineValueConstraints(rank, numParamLoopIVs, 0); regionCst->setValues(rank, rank + numParamLoopIVs, symbols); // Memref dim sizes provide the bounds. diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -660,10 +660,8 @@ unsigned nsLoops = getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst); - FlatAffineValueConstraints dependenceConstraints; - DependenceResult result = checkMemrefAccessDependence( - srcAccess, destAccess, nsLoops + 1, &dependenceConstraints, - /*dependenceComponents=*/nullptr); + DependenceResult result = + checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1); return hasDependence(result); } diff --git a/mlir/test/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/test/lib/Analysis/TestMemRefDependenceCheck.cpp --- a/mlir/test/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/test/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -81,10 +81,9 @@ unsigned numCommonLoops = getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - FlatAffineValueConstraints dependenceConstraints; SmallVector dependenceComponents; DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, + srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr, &dependenceComponents); if (result.value == DependenceResult::Failure) { srcOpInst->emitError("dependence check failed");