diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -613,7 +613,13 @@ The lower and upper bounds of a parallel operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The same restrictions hold for these SSA values as for all bindings of SSA - values to dimensions and symbols. + values to dimensions and symbols. The list of expressions in each map is + interpreted according to the respective bounds group attribute. If a single + expression belongs to the group, then the result of this expression is taken + as a lower(upper) bound of the corresponding loop induction variable. If + multiple expressions belong to the group, then the lower(upper) bound is the + max(min) of these values obtained from these expressions. The loop band has + as many loops as elements in the group bounds attributes. Each value yielded by affine.yield will be accumulated/reduced via one of the reduction methods defined in the AtomicRMWKind enum. 
The order of @@ -644,12 +650,25 @@ return %O } ``` + + Example (tiling by potentially imperfectly dividing sizes): + + ```mlir + affine.parallel (%ii, %jj) = (0, 0) to (%N, %M) step (32, 32) { + affine.parallel (%i, %j) = (%ii, %jj) + to (min(%ii + 32, %N), min(%jj + 32, %M)) { + call @f(%i, %j) : (index, index) -> () + } + } + ``` }]; let arguments = (ins TypedArrayAttrBase:$reductions, AffineMapAttr:$lowerBoundsMap, + I32ElementsAttr:$lowerBoundsGroups, AffineMapAttr:$upperBoundsMap, + I32ElementsAttr:$upperBoundsGroups, I64ArrayAttr:$steps, Variadic:$mapOperands); let results = (outs Variadic:$results); @@ -659,11 +678,8 @@ OpBuilder<(ins "TypeRange":$resultTypes, "ArrayRef":$reductions, "ArrayRef":$ranges)>, OpBuilder<(ins "TypeRange":$resultTypes, - "ArrayRef":$reductions, "AffineMap":$lbMap, - "ValueRange":$lbArgs, "AffineMap":$ubMap, "ValueRange":$ubArgs)>, - OpBuilder<(ins "TypeRange":$resultTypes, - "ArrayRef":$reductions, "AffineMap":$lbMap, - "ValueRange":$lbArgs, "AffineMap":$ubMap, "ValueRange":$ubArgs, + "ArrayRef":$reductions, "ArrayRef":$lbMaps, + "ValueRange":$lbArgs, "ArrayRef":$ubMaps, "ValueRange":$ubArgs, "ArrayRef":$steps)> ]; @@ -671,8 +687,6 @@ /// Get the number of dimensions. unsigned getNumDims(); - AffineValueMap getRangesValueMap(); - /// Get ranges as constants, may fail in dynamic case. Optional> getConstantRanges(); @@ -682,23 +696,45 @@ return getBody()->getArguments(); } + /// Returns elements of the loop lower bound. + AffineMap getLowerBoundMap(unsigned pos); operand_range getLowerBoundsOperands(); AffineValueMap getLowerBoundsValueMap(); + + /// Sets elements of the loop lower bound. void setLowerBounds(ValueRange operands, AffineMap map); void setLowerBoundsMap(AffineMap map); + /// Returns elements of the loop upper bound. + AffineMap getUpperBoundMap(unsigned pos); operand_range getUpperBoundsOperands(); AffineValueMap getUpperBoundsValueMap(); + + /// Sets elements of the loop upper bound. 
void setUpperBounds(ValueRange operands, AffineMap map); void setUpperBoundsMap(AffineMap map); SmallVector getSteps(); void setSteps(ArrayRef newSteps); + /// Returns attribute names to use in op construction. Not expected to be + /// used directly. static StringRef getReductionsAttrName() { return "reductions"; } static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; } + static StringRef getLowerBoundsGroupsAttrName() { + return "lowerBoundsGroups"; + } static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; } + static StringRef getUpperBoundsGroupsAttrName() { + return "upperBoundsGroups"; + } static StringRef getStepsAttrName() { return "steps"; } + + /// Returns `true` if the loop bounds have min/max expressions. + bool hasMinMaxBounds() { + return lowerBoundsMap().getNumResults() != getNumDims() || + upperBoundsMap().getNumResults() != getNumDims(); + } }]; let hasFolder = 1; diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -262,6 +262,9 @@ /// Returns the map consisting of the `resultPos` subset. AffineMap getSubMap(ArrayRef resultPos) const; + /// Returns the map consisting of `length` expressions starting from `start`. + AffineMap getSliceMap(unsigned start, unsigned length) const; + /// Returns the map consisting of the most major `numResults` results. /// Returns the null AffineMap if `numResults` == 0. /// Returns `*this` if `numResults` >= `this->getNumResults()`. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -113,6 +113,13 @@ virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands) = 0; + /// Prints an affine expression of SSA ids with SSA id names used instead of + /// dims and symbols. 
+ /// Operand values must come from single-result sources, and be valid + /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. + virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, + ValueRange symOperands) = 0; + /// Print an optional arrow followed by a type list. template void printOptionalArrowTypeList(TypeRange &&types) { @@ -680,6 +687,14 @@ StringRef attrName, NamedAttrList &attrs, Delimiter delimiter = Delimiter::Square) = 0; + /// Parses an affine expression where dims and symbols are SSA operands. + /// Operand values must come from single-result sources, and be valid + /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. + virtual ParseResult + parseAffineExprOfSSAIds(SmallVectorImpl &dimOperands, + SmallVectorImpl &symbOperands, + AffineExpr &expr) = 0; + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -423,20 +423,28 @@ SmallVector upperBoundTuple; SmallVector lowerBoundTuple; SmallVector identityVals; - // Finding lower and upper bound by expanding the map expression. - // Checking if expandAffineMap is not giving NULL. - Optional> lowerBound = expandAffineMap( - rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands()); - Optional> upperBound = expandAffineMap( - rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands()); - if (!lowerBound || !upperBound) - return failure(); - upperBoundTuple = *upperBound; - lowerBoundTuple = *lowerBound; + // Emit IR computing the lower and upper bound by expanding the map + // expression. 
+ lowerBoundTuple.reserve(op.getNumDims()); + upperBoundTuple.reserve(op.getNumDims()); + for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) { + Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i), + op.getLowerBoundsOperands()); + if (!lower) + return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds"); + lowerBoundTuple.push_back(lower); + + Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i), + op.getUpperBoundsOperands()); + if (!upper) + return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds"); + upperBoundTuple.push_back(upper); + } steps.reserve(op.steps().size()); for (Attribute step : op.steps()) steps.push_back(rewriter.create( loc, step.cast().getInt())); + // Get the terminator op. Operation *affineParOpTerminator = op.getBody()->getTerminator(); scf::ParallelOp parOp; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2604,45 +2604,46 @@ TypeRange resultTypes, ArrayRef reductions, ArrayRef ranges) { - SmallVector lbExprs(ranges.size(), - builder.getAffineConstantExpr(0)); - auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext()); - SmallVector ubExprs; - for (int64_t range : ranges) - ubExprs.push_back(builder.getAffineConstantExpr(range)); - auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); - build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap, - /*ubArgs=*/{}); + SmallVector lbs(ranges.size(), builder.getConstantAffineMap(0)); + auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) { + return builder.getConstantAffineMap(value); + })); + SmallVector steps(ranges.size(), 1); + build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs, + /*ubArgs=*/{}, steps); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef 
reductions, - AffineMap lbMap, ValueRange lbArgs, - AffineMap ubMap, ValueRange ubArgs) { - auto numDims = lbMap.getNumResults(); - // Verify that the dimensionality of both maps are the same. - assert(numDims == ubMap.getNumResults() && - "num dims and num results mismatch"); - // Make default step sizes of 1. - SmallVector steps(numDims, 1); - build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs, - steps); -} - -void AffineParallelOp::build(OpBuilder &builder, OperationState &result, - TypeRange resultTypes, - ArrayRef reductions, - AffineMap lbMap, ValueRange lbArgs, - AffineMap ubMap, ValueRange ubArgs, + ArrayRef lbMaps, ValueRange lbArgs, + ArrayRef ubMaps, ValueRange ubArgs, ArrayRef steps) { - auto numDims = lbMap.getNumResults(); - // Verify that the dimensionality of the maps matches the number of steps. - assert(numDims == ubMap.getNumResults() && - "num dims and num results mismatch"); - assert(numDims == steps.size() && "num dims and num steps mismatch"); + assert(!lbMaps.empty() && "expected the lower bound map to be non-empty"); + assert(!ubMaps.empty() && "expected the upper bound map to be non-empty"); + assert(llvm::all_of(lbMaps, + [lbMaps](AffineMap m) { + return m.getNumDims() == lbMaps[0].getNumDims() && + m.getNumSymbols() == lbMaps[0].getNumSymbols(); + }) && + "expected all lower bounds maps to have the same number of dimensions " + "and symbols"); + assert(llvm::all_of(ubMaps, + [ubMaps](AffineMap m) { + return m.getNumDims() == ubMaps[0].getNumDims() && + m.getNumSymbols() == ubMaps[0].getNumSymbols(); + }) && + "expected all upper bounds maps to have the same number of dimensions " + "and symbols"); + assert(lbMaps[0].getNumInputs() == lbArgs.size() && + "expected lower bound maps to have as many inputs as lower bound " + "operands"); + assert(ubMaps[0].getNumInputs() == ubArgs.size() && + "expected upper bound maps to have as many inputs as upper bound " + "operands"); result.addTypes(resultTypes); + // Convert 
the reductions to integer attributes. SmallVector reductionAttrs; for (AtomicRMWKind reduction : reductions) @@ -2650,16 +2651,42 @@ builder.getI64IntegerAttr(static_cast(reduction))); result.addAttribute(getReductionsAttrName(), builder.getArrayAttr(reductionAttrs)); + + // Concatenates maps defined in the same input space (same dimensions and + // symbols), assumes there is at least one map. + auto concatMapsSameInput = [](ArrayRef maps, + SmallVectorImpl &groups) { + SmallVector exprs; + groups.reserve(groups.size() + maps.size()); + exprs.reserve(maps.size()); + for (AffineMap m : maps) { + llvm::append_range(exprs, m.getResults()); + groups.push_back(m.getNumResults()); + } + assert(!maps.empty() && "expected a non-empty list of maps"); + return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs, + maps[0].getContext()); + }; + + // Set up the bounds. + SmallVector lbGroups, ubGroups; + AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups); + AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups); result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); + result.addAttribute(getLowerBoundsGroupsAttrName(), + builder.getI32VectorAttr(lbGroups)); result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); + result.addAttribute(getUpperBoundsGroupsAttrName(), + builder.getI32VectorAttr(ubGroups)); result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); result.addOperands(lbArgs); result.addOperands(ubArgs); + // Create a region and a block for the body. auto *bodyRegion = result.addRegion(); auto *body = new Block(); // Add all the block arguments. 
- for (unsigned i = 0; i < numDims; ++i) + for (unsigned i = 0, e = steps.size(); i < e; ++i) body->addArgument(IndexType::get(builder.getContext())); bodyRegion->push_back(body); if (resultTypes.empty()) @@ -2688,6 +2715,22 @@ return getOperands().drop_front(lowerBoundsMap().getNumInputs()); } +AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { + unsigned start = 0; + for (unsigned i = 0; i < pos; ++i) + start += lowerBoundsGroups().getValue(i); + return lowerBoundsMap().getSliceMap( + start, lowerBoundsGroups().getValue(pos)); +} + +AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { + unsigned start = 0; + for (unsigned i = 0; i < pos; ++i) + start += upperBoundsGroups().getValue(i); + return upperBoundsMap().getSliceMap( + start, upperBoundsGroups().getValue(pos)); +} + AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands()); } @@ -2696,17 +2739,15 @@ return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands()); } -AffineValueMap AffineParallelOp::getRangesValueMap() { - AffineValueMap out; - AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), - &out); - return out; -} - Optional> AffineParallelOp::getConstantRanges() { + if (hasMinMaxBounds()) + return llvm::None; + // Try to convert all the ranges to constant expressions. 
SmallVector out; - AffineValueMap rangesValueMap = getRangesValueMap(); + AffineValueMap rangesValueMap; + AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), + &rangesValueMap); out.reserve(rangesValueMap.getNumResults()); for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { auto expr = rangesValueMap.getResult(i); @@ -2780,12 +2821,32 @@ static LogicalResult verify(AffineParallelOp op) { auto numDims = op.getNumDims(); - if (op.lowerBoundsMap().getNumResults() != numDims || - op.upperBoundsMap().getNumResults() != numDims || + if (op.lowerBoundsGroups().getNumElements() != numDims || + op.upperBoundsGroups().getNumElements() != numDims || op.steps().size() != numDims || - op.getBody()->getNumArguments() != numDims) - return op.emitOpError("region argument count and num results of upper " - "bounds, lower bounds, and steps must all match"); + op.getBody()->getNumArguments() != numDims) { + return op.emitOpError() + << "the number of region arguments (" + << op.getBody()->getNumArguments() + << ") and the number of map groups for lower (" + << op.lowerBoundsGroups().getNumElements() << ") and upper bound (" + << op.upperBoundsGroups().getNumElements() + << "), and the number of steps (" << op.steps().size() + << ") must all match"; + } + + unsigned expectedNumLBResults = 0; + for (APInt v : op.lowerBoundsGroups()) + expectedNumLBResults += v.getZExtValue(); + if (expectedNumLBResults != op.lowerBoundsMap().getNumResults()) + return op.emitOpError() << "expected lower bounds map to have " + << expectedNumLBResults << " results"; + unsigned expectedNumUBResults = 0; + for (APInt v : op.upperBoundsGroups()) + expectedNumUBResults += v.getZExtValue(); + if (expectedNumUBResults != op.upperBoundsMap().getNumResults()) + return op.emitOpError() << "expected upper bounds map to have " + << expectedNumUBResults << " results"; if (op.reductions().size() != op.getNumResults()) return op.emitOpError("a reduction must be specified 
for each output"); @@ -2844,13 +2905,44 @@ return canonicalizeLoopBounds(*this); } +/// Prints a lower(upper) bound of an affine parallel loop with max(min) +/// conditions in it. `mapAttr` is a flat list of affine expressions and `group` +/// identifies which of those expressions form max/min groups. `operands` +/// are the SSA values of dimensions and symbols and `keyword` is either "min" +/// or "max". +static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, + DenseIntElementsAttr group, ValueRange operands, + StringRef keyword) { + AffineMap map = mapAttr.getValue(); + unsigned numDims = map.getNumDims(); + ValueRange dimOperands = operands.take_front(numDims); + ValueRange symOperands = operands.drop_front(numDims); + unsigned start = 0; + for (llvm::APInt groupSize : group) { + if (start != 0) + p << ", "; + + unsigned size = groupSize.getZExtValue(); + if (size == 1) { + p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands); + ++start; + } else { + p << keyword << '('; + AffineMap submap = map.getSliceMap(start, size); + p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands); + p << ')'; + start += size; + } + } +} + static void print(OpAsmPrinter &p, AffineParallelOp op) { p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; - p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(), - op.getLowerBoundsOperands()); + printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(), + op.getLowerBoundsOperands(), "max"); p << ") to ("; - p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(), - op.getUpperBoundsOperands()); + printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(), + op.getUpperBoundsOperands(), "min"); p << ')'; SmallVector steps = op.getSteps(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); @@ -2875,39 +2967,171 @@ op->getAttrs(), /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), 
AffineParallelOp::getLowerBoundsMapAttrName(), + AffineParallelOp::getLowerBoundsGroupsAttrName(), AffineParallelOp::getUpperBoundsMapAttrName(), + AffineParallelOp::getUpperBoundsGroupsAttrName(), AffineParallelOp::getStepsAttrName()}); } +/// Given a list of lists of parsed operands, populates `uniqueOperands` with +/// unique operands. Also populates `replacements` with affine expressions of +/// `kind` that can be used to update affine maps previously accepting +/// `operands` to accept `uniqueOperands` instead. +static void deduplicateAndResolveOperands( + OpAsmParser &parser, + ArrayRef> operands, + SmallVectorImpl &uniqueOperands, + SmallVectorImpl &replacements, AffineExprKind kind) { + assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) && + "expected operands to be dim or symbol expression"); + + Type indexType = parser.getBuilder().getIndexType(); + for (const auto &list : operands) { + SmallVector valueOperands; + parser.resolveOperands(list, indexType, valueOperands); + for (Value operand : valueOperands) { + unsigned pos = std::distance(uniqueOperands.begin(), + llvm::find(uniqueOperands, operand)); + if (pos == uniqueOperands.size()) + uniqueOperands.push_back(operand); + replacements.push_back( + kind == AffineExprKind::DimId + ? getAffineDimExpr(pos, parser.getBuilder().getContext()) + : getAffineSymbolExpr(pos, parser.getBuilder().getContext())); + } + } +} + +namespace { +enum class MinMaxKind { Min, Max }; +} // namespace + +/// Parses an affine map that can contain a min/max for groups of its results, +/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates +/// `result` attributes with the map (flat list of expressions) and the grouping +/// (list of integers that specify how many expressions to put into each +/// min/max) attributes. Deduplicates repeated operands. +/// +/// parallel-bound ::= `(` parallel-group-list `)` +/// parallel-group-list ::= parallel-group (`,` parallel-group-list)? 
+/// parallel-group ::= simple-group | min-max-group +/// simple-group ::= expr-of-ssa-ids +/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)` +/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)? +/// +/// Examples: +/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6)) +/// (%0, max(%1 - 2 * %2)) +static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, + OperationState &result, + MinMaxKind kind) { + constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map"; + + StringRef mapName = kind == MinMaxKind::Min + ? AffineParallelOp::getUpperBoundsMapAttrName() + : AffineParallelOp::getLowerBoundsMapAttrName(); + StringRef groupsName = kind == MinMaxKind::Min + ? AffineParallelOp::getUpperBoundsGroupsAttrName() + : AffineParallelOp::getLowerBoundsGroupsAttrName(); + + if (failed(parser.parseLParen())) + return failure(); + + if (succeeded(parser.parseOptionalRParen())) { + result.addAttribute( + mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap())); + result.addAttribute(groupsName, parser.getBuilder().getI32VectorAttr({})); + return success(); + } + + SmallVector flatExprs; + SmallVector> flatDimOperands; + SmallVector> flatSymOperands; + SmallVector numMapsPerGroup; + SmallVector mapOperands; + do { + if (succeeded(parser.parseOptionalKeyword( + kind == MinMaxKind::Min ? 
"min" : "max"))) { + mapOperands.clear(); + AffineMapAttr map; + if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName, + result.attributes, + OpAsmParser::Delimiter::Paren))) + return failure(); + result.attributes.erase(tmpAttrName); + llvm::append_range(flatExprs, map.getValue().getResults()); + auto operandsRef = llvm::makeArrayRef(mapOperands); + auto dimsRef = operandsRef.take_front(map.getValue().getNumDims()); + SmallVector dims(dimsRef.begin(), + dimsRef.end()); + auto symsRef = operandsRef.drop_front(map.getValue().getNumDims()); + SmallVector syms(symsRef.begin(), + symsRef.end()); + flatDimOperands.append(map.getValue().getNumResults(), dims); + flatSymOperands.append(map.getValue().getNumResults(), syms); + numMapsPerGroup.push_back(map.getValue().getNumResults()); + } else { + if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(), + flatSymOperands.emplace_back(), + flatExprs.emplace_back()))) + return failure(); + numMapsPerGroup.push_back(1); + } + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseRParen())) + return failure(); + + unsigned totalNumDims = 0; + unsigned totalNumSyms = 0; + for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { + unsigned numDims = flatDimOperands[i].size(); + unsigned numSyms = flatSymOperands[i].size(); + flatExprs[i] = flatExprs[i] + .shiftDims(numDims, totalNumDims) + .shiftSymbols(numSyms, totalNumSyms); + totalNumDims += numDims; + totalNumSyms += numSyms; + } + + // Deduplicate map operands. 
+ SmallVector dimOperands, symOperands; + SmallVector dimRplacements, symRepacements; + deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands, + dimRplacements, AffineExprKind::DimId); + deduplicateAndResolveOperands(parser, flatSymOperands, symOperands, + symRepacements, AffineExprKind::SymbolId); + + result.operands.append(dimOperands.begin(), dimOperands.end()); + result.operands.append(symOperands.begin(), symOperands.end()); + + Builder &builder = parser.getBuilder(); + auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs, + parser.getBuilder().getContext()); + flatMap = flatMap.replaceDimsAndSymbols( + dimRplacements, symRepacements, dimOperands.size(), symOperands.size()); + + result.addAttribute(mapName, AffineMapAttr::get(flatMap)); + result.addAttribute(groupsName, builder.getI32VectorAttr(numMapsPerGroup)); + return success(); +} + // -// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)` -// `to` `(` map-of-ssa-ids `)` steps? region attr-dict? +// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound +// `to` parallel-bound steps? region attr-dict? 
// steps ::= `steps` `(` integer-literals `)` // static ParseResult parseAffineParallelOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); - AffineMapAttr lowerBoundsAttr, upperBoundsAttr; SmallVector ivs; - SmallVector lowerBoundsMapOperands; - SmallVector upperBoundsMapOperands; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || parser.parseEqual() || - parser.parseAffineMapOfSSAIds( - lowerBoundsMapOperands, lowerBoundsAttr, - AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes, - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(lowerBoundsMapOperands, indexType, - result.operands) || + parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || parser.parseKeyword("to") || - parser.parseAffineMapOfSSAIds( - upperBoundsMapOperands, upperBoundsAttr, - AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes, - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(upperBoundsMapOperands, indexType, - result.operands)) + parseAffineMapWithMinMax(parser, result, MinMaxKind::Min)) return failure(); AffineMapAttr stepsMapAttr; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp @@ -21,6 +21,10 @@ using namespace mlir; void mlir::normalizeAffineParallel(AffineParallelOp op) { + // Loops with min/max in bounds are not normalized at the moment. + if (op.hasMinMaxBounds()) + return; + AffineMap lbMap = op.lowerBoundsMap(); SmallVector steps = op.getSteps(); // No need to do any work if the parallel op is already normalized. 
@@ -34,7 +38,9 @@ if (isAlreadyNormalized) return; - AffineValueMap ranges = op.getRangesValueMap(); + AffineValueMap ranges; + AffineValueMap::difference(op.getUpperBoundsValueMap(), + op.getLowerBoundsValueMap(), &ranges); auto builder = OpBuilder::atBlockBegin(op.getBody()); auto zeroExpr = builder.getAffineConstantExpr(0); SmallVector lbExprs; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -145,47 +145,21 @@ Location loc = forOp.getLoc(); OpBuilder outsideBuilder(forOp); - - // If a loop has a 'max' in the lower bound, emit it outside the parallel loop - // as it does not have implicit 'max' behavior. AffineMap lowerBoundMap = forOp.getLowerBoundMap(); ValueRange lowerBoundOperands = forOp.getLowerBoundOperands(); AffineMap upperBoundMap = forOp.getUpperBoundMap(); ValueRange upperBoundOperands = forOp.getUpperBoundOperands(); - bool needsMax = lowerBoundMap.getNumResults() > 1; - bool needsMin = upperBoundMap.getNumResults() > 1; - AffineMap identityMap; - if (needsMax || needsMin) { - if (forOp->getParentOp() && - !forOp->getParentOp()->hasTrait()) - return failure(); - - identityMap = AffineMap::getMultiDimIdentityMap(1, loc->getContext()); - } - if (needsMax) { - auto maxOp = outsideBuilder.create(loc, lowerBoundMap, - lowerBoundOperands); - lowerBoundMap = identityMap; - lowerBoundOperands = maxOp->getResults(); - } - - // Same for the upper bound. - if (needsMin) { - auto minOp = outsideBuilder.create(loc, upperBoundMap, - upperBoundOperands); - upperBoundMap = identityMap; - upperBoundOperands = minOp->getResults(); - } - // Creating empty 1-D affine.parallel op. 
auto reducedValues = llvm::to_vector<4>(llvm::map_range( parallelReductions, [](const LoopReduction &red) { return red.value; })); auto reductionKinds = llvm::to_vector<4>(llvm::map_range( parallelReductions, [](const LoopReduction &red) { return red.kind; })); AffineParallelOp newPloop = outsideBuilder.create( - loc, ValueRange(reducedValues).getTypes(), reductionKinds, lowerBoundMap, - lowerBoundOperands, upperBoundMap, upperBoundOperands); + loc, ValueRange(reducedValues).getTypes(), reductionKinds, + llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands, + llvm::makeArrayRef(upperBoundMap), upperBoundOperands, + llvm::makeArrayRef(forOp.getStep())); // Steal the body of the old affine for op. newPloop.region().takeBody(forOp.region()); Operation *yieldOp = &newPloop.getBody()->back(); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -494,6 +494,11 @@ return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); } +AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { + return AffineMap::get(getNumDims(), getNumSymbols(), + getResults().slice(start, length), getContext()); +} + AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { if (numResults == 0) return AffineMap(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -469,6 +469,7 @@ /// The following are hooks of `OpAsmPrinter` that are not necessary for /// determining potential aliases. 
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} + void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} void printNewline() override {} void printOperand(Value) override {} void printOperand(Value, raw_ostream &os) override { @@ -2351,6 +2352,11 @@ void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands) override; + /// Print the given affine expression with the symbol and dimension operands + /// printed inline with the expression. + void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, + ValueRange symOperands) override; + /// Print the given string as a symbol reference. void printSymbolName(StringRef symbolRef) override { ::printSymbolReference(symbolRef, os); @@ -2590,6 +2596,19 @@ }); } +void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, + ValueRange dimOperands, + ValueRange symOperands) { + auto printValueName = [&](unsigned pos, bool isSymbol) { + if (!isSymbol) + return printValueID(dimOperands[pos]); + os << "symbol("; + printValueID(symOperands[pos]); + os << ')'; + }; + printAffineExpr(expr, printValueName); +} + //===----------------------------------------------------------------------===// // print and dump methods //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp --- a/mlir/lib/Parser/AffineParser.cpp +++ b/mlir/lib/Parser/AffineParser.cpp @@ -55,6 +55,7 @@ IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); ParseResult parseAffineMapOfSSAIds(AffineMap &map, OpAsmParser::Delimiter delimiter); + ParseResult parseAffineExprOfSSAIds(AffineExpr &expr); void getDimsAndSymbolSSAIds(SmallVectorImpl &dimAndSymbolSSAIds, unsigned &numDims); @@ -579,6 +580,12 @@ return success(); } +/// Parse an AffineExpr where the dim and symbol identifiers are SSA ids. 
+ParseResult AffineParser::parseAffineExprOfSSAIds(AffineExpr &expr) { + expr = parseAffineExpr(); + return success(expr != nullptr); +} + /// Parse the range and sizes affine map definition inline. /// /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr @@ -724,3 +731,12 @@ return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) .parseAffineMapOfSSAIds(map, delimiter); } + +/// Parse an AffineExpr of SSA ids. The callback `parseElement` is used to parse +/// SSA value uses encountered while parsing. +ParseResult +Parser::parseAffineExprOfSSAIds(AffineExpr &expr, + function_ref parseElement) { + return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) + .parseAffineExprOfSSAIds(expr); +} diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -268,6 +268,11 @@ function_ref parseElement, OpAsmParser::Delimiter delimiter); + /// Parse an AffineExpr where dim and symbol identifiers are SSA ids. + ParseResult + parseAffineExprOfSSAIds(AffineExpr &expr, + function_ref parseElement); + protected: /// The Parser is subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1513,6 +1513,25 @@ return success(); } + /// Parse an AffineExpr of SSA ids. 
+ ParseResult + parseAffineExprOfSSAIds(SmallVectorImpl &dimOperands, + SmallVectorImpl &symbOperands, + AffineExpr &expr) override { + auto parseElement = [&](bool isSymbol) -> ParseResult { + OperandType operand; + if (parseOperand(operand)) + return failure(); + if (isSymbol) + symbOperands.push_back(operand); + else + dimOperands.push_back(operand); + return success(); + }; + + return parser.parseAffineExprOfSSAIds(expr, parseElement); + } + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -740,8 +740,8 @@ } // CHECK-LABEL: func @affine_parallel_simple // CHECK: %[[LOWER_1:.*]] = constant 0 : index -// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index // CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index +// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index // CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index // CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index // CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index @@ -800,8 +800,8 @@ } // CHECK-LABEL: func @affine_parallel_with_reductions // CHECK: %[[LOWER_1:.*]] = constant 0 : index -// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index // CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index +// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index // CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index // CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index // CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index @@ -841,8 +841,8 @@ } // CHECK-LABEL: @affine_parallel_with_reductions_f64 // CHECK: %[[LOWER_1:.*]] = constant 0 : index -// CHECK: %[[LOWER_2:.*]] = constant 0 : index // CHECK: %[[UPPER_1:.*]] = constant 2 : index +// CHECK: %[[LOWER_2:.*]] = constant 0 : 
index // CHECK: %[[UPPER_2:.*]] = constant 2 : index // CHECK: %[[STEP_1:.*]] = constant 1 : index // CHECK: %[[STEP_2:.*]] = constant 1 : index @@ -880,8 +880,8 @@ } // CHECK-LABEL: @affine_parallel_with_reductions_i64 // CHECK: %[[LOWER_1:.*]] = constant 0 : index -// CHECK: %[[LOWER_2:.*]] = constant 0 : index // CHECK: %[[UPPER_1:.*]] = constant 2 : index +// CHECK: %[[LOWER_2:.*]] = constant 0 : index // CHECK: %[[UPPER_2:.*]] = constant 2 : index // CHECK: %[[STEP_1:.*]] = constant 1 : index // CHECK: %[[STEP_2:.*]] = constant 1 : index diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -197,7 +197,7 @@ // ----- func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}} + // expected-error@+1 {{the number of region arguments (1) and the number of map groups for lower (2) and upper bound (2), and the number of steps (2) must all match}} affine.parallel (%i) = (0, 0) to (100, 100) step (10, 10) { } } @@ -205,7 +205,7 @@ // ----- func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}} + // expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (1) and upper bound (2), and the number of steps (2) must all match}} affine.parallel (%i, %j) = (0) to (100, 100) step (10, 10) { } } @@ -213,7 +213,7 @@ // ----- func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}} + // expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (2) and upper bound (1), and the number of steps 
(2) must all match}} affine.parallel (%i, %j) = (0, 0) to (100) step (10, 10) { } } @@ -221,7 +221,7 @@ // ----- func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{region argument count and num results of upper bounds, lower bounds, and steps must all match}} + // expected-error@+1 {{the number of region arguments (2) and the number of map groups for lower (2) and upper bound (2), and the number of steps (1) must all match}} affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10) { } } diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -169,6 +169,21 @@ // ----- +// CHECK-LABEL: @parallel_min_max +// CHECK: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index +func @parallel_min_max(%a: index, %b: index, %c: index, %d: index) { + // CHECK: affine.parallel (%{{.*}}, %{{.*}}, %{{.*}}) = + // CHECK: (max(%[[A]], %[[B]]) + // CHECK: to (%[[C]], min(%[[C]], %[[D]]), %[[B]]) + affine.parallel (%i, %j, %k) = (max(%a, %b), %b, max(%a, %c)) + to (%c, min(%c, %d), %b) { + affine.yield + } + return +} + +// ----- + // CHECK-LABEL: func @affine_if func @affine_if() -> f32 { // CHECK: %[[ZERO:.*]] = constant {{.*}} : f32 diff --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir --- a/mlir/test/Dialect/Affine/parallelize.mlir +++ b/mlir/test/Dialect/Affine/parallelize.mlir @@ -120,9 +120,7 @@ // CHECK-LABEL: for_with_minmax func @for_with_minmax(%m: memref, %lb0: index, %lb1: index, %ub0: index, %ub1: index) { - // CHECK: %[[lb:.*]] = affine.max - // CHECK: %[[ub:.*]] = affine.min - // CHECK: affine.parallel (%{{.*}}) = (%[[lb]]) to (%[[ub]]) + // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %{{.*}})) to (min(%{{.*}}, %{{.*}})) affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %lb1) to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) { affine.load 
%m[%i] : memref @@ -133,12 +131,9 @@ // CHECK-LABEL: nested_for_with_minmax func @nested_for_with_minmax(%m: memref, %lb0: index, %ub0: index, %ub1: index) { - // CHECK: affine.parallel + // CHECK: affine.parallel (%[[I:.*]]) = affine.for %j = 0 to 10 { - // Cannot parallelize the inner loop because we would need to compute - // affine.max for its lower bound inside the loop, and that is not (yet) - // considered as a valid affine dimension. - // CHECK: affine.for + // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %[[I]])) to (min(%{{.*}}, %{{.*}})) affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j) to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) { affine.load %m[%i] : memref @@ -236,3 +231,20 @@ } return } + +// REDUCE-LABEL: @nested_min_max +// CHECK-LABEL: @nested_min_max +// CHECK: (%{{.*}}, %[[LB0:.*]]: index, %[[UB0:.*]]: index, %[[UB1:.*]]: index) +func @nested_min_max(%m: memref, %lb0: index, + %ub0: index, %ub1: index) { + // CHECK: affine.parallel (%[[J:.*]]) = + affine.for %j = 0 to 10 { + // CHECK: affine.parallel (%{{.*}}) = (max(%[[LB0]], %[[J]])) + // CHECK: to (min(%[[UB0]], %[[UB1]])) + affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j) + to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) { + affine.load %m[%i] : memref + } + } + return +}