diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -42,6 +42,15 @@
 LogicalResult getIndexSet(MutableArrayRef<AffineForOp> forOps,
                           FlatAffineConstraints *domain);
 
+/// A generic version of getIndexSet that handles both AffineForOp and
+/// AffineIfOp operations. `domain` is first reset so that its dimensional
+/// identifiers are the IVs of the AffineForOps in `ops`, in order; then the
+/// bounds of every AffineForOp and the integer set of every AffineIfOp (see
+/// FlatAffineConstraints::addAffineIfOpDomain) are added to it. Operations of
+/// any other kind are ignored. Returns failure if any domain cannot be added.
+LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
+                          FlatAffineConstraints *domain);
+
 /// Encapsulates a memref load or store access information.
 struct MemRefAccess {
   Value memref;
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -21,6 +21,7 @@
 
 class AffineCondition;
 class AffineForOp;
+class AffineIfOp;
 class AffineMap;
 class AffineValueMap;
 class IntegerSet;
@@ -215,6 +216,14 @@
   // TODO: add support for non-unit strides.
   LogicalResult addAffineForOpDomain(AffineForOp forOp);
 
+  /// Adds the constraints imposed by the `affine.if` operation `ifOp`, i.e.
+  /// the constraints of the IntegerSet attached to it. Since an IntegerSet
+  /// doesn't bind its identifiers (variables) to specific SSA values, the
+  /// set's dim and symbol identifiers are first bound to ifOp's operands and
+  /// then merged and aligned with the identifiers of this system (operands
+  /// not already present here are added as new identifiers).
+  LogicalResult addAffineIfOpDomain(AffineIfOp ifOp);
+
   /// Adds a lower or an upper bound for the identifier at the specified
   /// position with constraints being drawn from the specified bound map and
   /// operands. If `eq` is true, add a single equality equal to the bound map's
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -39,6 +39,12 @@
 // TODO: handle 'affine.if' ops.
 void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
 
+/// Populates `ops` with the `affine.for` and `affine.if` operations enclosing
+/// `op`, ordered from outermost to innermost. The walk up the parent chain
+/// stops at the first ancestor that is neither an `affine.for` nor an
+/// `affine.if`.
+void getIVs(Operation &op, SmallVectorImpl<Operation *> *ops);
+
 /// Returns the nesting depth of this operation, i.e., the number of loops
 /// surrounding this operation.
 unsigned getNestingDepth(Operation *op);
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -99,6 +99,32 @@
   return success();
 }
 
+LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
+                                FlatAffineConstraints *domain) {
+  SmallVector<Value, 4> indices;
+  SmallVector<AffineForOp, 8> forOps;
+
+  // Collect the affine.for ops; their IVs become the domain's initial dims.
+  for (Operation *op : ops)
+    if (auto forOp = dyn_cast<AffineForOp>(op))
+      forOps.push_back(forOp);
+
+  extractForInductionVars(forOps, &indices);
+  // Reset the domain, associating its dims with the Values in 'indices'.
+  domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+  for (Operation *op : ops) {
+    // Add constraints from forOp's bounds or from ifOp's integer set.
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      if (failed(domain->addAffineForOpDomain(forOp)))
+        return failure();
+    } else if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
+      if (failed(domain->addAffineIfOpDomain(ifOp)))
+        return failure();
+    }
+  }
+  return success();
+}
+
 // Computes the iteration domain for 'opInst' and populates 'indexSet', which
 // encapsulates the constraints involving loops surrounding 'opInst' and
 // potentially involving any Function symbols. The dimensional identifiers in
@@ -109,9 +135,9 @@
                                      FlatAffineConstraints *indexSet) {
-  // TODO: Extend this to gather enclosing IfInsts and consider
-  // factoring it out into a utility function.
+  // Gather the enclosing affine.for and affine.if ops (outermost first) and
+  // build the index set from both kinds of ops.
-  SmallVector<AffineForOp, 4> loops;
-  getLoopIVs(*op, &loops);
-  return getIndexSet(loops, indexSet);
+  SmallVector<Operation *, 4> ops;
+  getIVs(*op, &ops);
+  return getIndexSet(ops, indexSet);
 }
 
 namespace {
@@ -209,32 +235,62 @@
     const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
     const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
     FlatAffineConstraints *dependenceConstraints) {
-  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
+  // IsDimState is a tri-state that tells updateValuePosMap how to classify
+  // the values passed to it:
+  // - True: all values are dim values.
+  // - False: all values are symbol values.
+  // - Unknown: each value is classified individually; it is a dim if it is an
+  //   affine.for induction variable and a symbol otherwise.
+  //
+  // The explicit states are needed because whether a value passed to the
+  // IntegerSet of an `affine.if` is a dim or a symbol cannot be determined
+  // from the Value itself. srcDomain and dstDomain, however, keep their dim
+  // and symbol identifiers in separate ranges, so their values are passed
+  // with True or False; only the access map operands rely on Unknown.
+  //
+  // A scoped enum keeps the enumerators from colliding with the TRUE/FALSE
+  // macros that some system headers define.
+  enum class IsDimState { True, False, Unknown };
+
+  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
+                               IsDimState isDim) {
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
       auto value = values[i];
-      if (!isForInductionVar(values[i])) {
-        assert(isValidSymbol(values[i]) &&
+      if (isDim == IsDimState::False ||
+          (isDim == IsDimState::Unknown && !isForInductionVar(value))) {
+        assert(isValidSymbol(value) &&
                "access operand has to be either a loop IV or a symbol");
         valuePosMap->addSymbolValue(value);
       } else if (isSrc) {
         valuePosMap->addSrcValue(value);
       } else {
         valuePosMap->addDstValue(value);
       }
     }
   };
 
-  SmallVector<Value, 4> srcValues, destValues;
-  srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
-  dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
-  // Update value position map with identifiers from src iteration domain.
-  updateValuePosMap(srcValues, /*isSrc=*/true);
-  // Update value position map with identifiers from dst iteration domain.
-  updateValuePosMap(destValues, /*isSrc=*/false);
+  SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
+      dstSymbolValues;
+  srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
+  dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
+  srcDomain.getIdValues(srcDomain.getNumDimIds(),
+                        srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
+  dstDomain.getIdValues(dstDomain.getNumDimIds(),
+                        dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+  // Update value position map with dim values from src iteration domain.
+  updateValuePosMap(srcDimValues, /*isSrc=*/true, IsDimState::True);
+  // Update value position map with dim values from dst iteration domain.
+  updateValuePosMap(dstDimValues, /*isSrc=*/false, IsDimState::True);
+  // Update value position map with symbols from src iteration domain.
+  updateValuePosMap(srcSymbolValues, /*isSrc=*/true, IsDimState::False);
+  // Update value position map with symbols from dst iteration domain.
+  updateValuePosMap(dstSymbolValues, /*isSrc=*/false, IsDimState::False);
   // Update value position map with identifiers from src access function.
-  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
+  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
+                    IsDimState::Unknown);
   // Update value position map with identifiers from dst access function.
-  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
+  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
+                    IsDimState::Unknown);
 }
 
 // Sets up dependence constraints columns appropriately, in the format:
@@ -271,17 +327,20 @@
       srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
 
   // Set values for the symbolic identifier dimensions.
-  auto setSymbolIds = [&](ArrayRef<Value> values) {
+  // When `isSymbolDetermined` is true, every value is treated as a symbol;
+  // otherwise, the affine.for induction variables among `values` are skipped.
+  auto setSymbolIds = [&](ArrayRef<Value> values,
+                          bool isSymbolDetermined = true) {
     for (auto value : values) {
-      if (!isForInductionVar(value)) {
+      if (isSymbolDetermined || !isForInductionVar(value)) {
         assert(isValidSymbol(value) && "expected symbol");
         dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
       }
     }
   };
 
-  setSymbolIds(srcAccessMap.getOperands());
-  setSymbolIds(dstAccessMap.getOperands());
+  setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
+  setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
 
   SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
   srcDomain.getIdValues(srcDomain.getNumDimIds(),
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -723,6 +723,27 @@
                               /*eq=*/false, /*lower=*/false);
 }
 
+LogicalResult FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
+  // Create the base constraints from the integer set attached to ifOp.
+  FlatAffineConstraints cst(ifOp.getIntegerSet());
+
+  // Bind ids in the constraints to ifOp operands. The operands correspond to
+  // the set's dims and symbols only; any local ids the flattened set may
+  // carry have no operand, so only the dim and symbol ids are bound.
+  SmallVector<Value, 4> operands = ifOp.getOperands();
+  cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
+
+  // Merge and align the ids of both systems so that they share the same
+  // identifier list, then add ifOp's constraints to the current domain.
+  // mergeAndAlignIdsWithOther only aligns identifiers; without the append,
+  // the constraints of the integer set would be dropped.
+  mergeAndAlignIdsWithOther(0, &cst);
+  append(cst);
+
+  // Failure will only be raised by assertions.
+  return success();
+}
+
 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
 // equality (isEq=true) or inequality (isEq=false) constraints.
 // Returns true and sets row found in search in 'rowIdx'.
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -44,6 +44,20 @@
   std::reverse(loops->begin(), loops->end());
 }
 
+// Collects the affine.for and affine.if ops enclosing 'op', stopping at the
+// first ancestor that is neither, and appends them to 'ops' ordered from
+// outermost to innermost.
+void mlir::getIVs(Operation &op, SmallVectorImpl<Operation *> *ops) {
+  // Remember where the new entries start so that only they are reversed and
+  // any pre-existing contents of 'ops' keep their order.
+  size_t firstNew = ops->size();
+  for (Operation *currOp = op.getParentOp();
+       currOp && (isa<AffineForOp>(currOp) || isa<AffineIfOp>(currOp));
+       currOp = currOp->getParentOp())
+    ops->push_back(currOp);
+  std::reverse(ops->begin() + firstNew, ops->end());
+}
+
 // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
 LogicalResult
 ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {