diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -28,7 +29,7 @@ class AffineBound; class AffineDimExpr; class AffineValueMap; -class AffineTerminatorOp; +class AffineYieldOp; class FlatAffineConstraints; class OpBuilder; diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -14,6 +14,7 @@ #define AFFINE_OPS include "mlir/Dialect/Affine/IR/AffineOpsBase.td" +include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -38,9 +39,9 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -// Require regions to have affine terminator. +// Require regions to have affine.yield. def ImplicitAffineTerminator - : SingleBlockImplicitTerminator<"AffineTerminatorOp">; + : SingleBlockImplicitTerminator<"AffineYieldOp">; def AffineApplyOp : Affine_Op<"apply", [NoSideEffect]> { let summary = "affine apply operation"; @@ -125,7 +126,7 @@ The `affine.for` operation represents an affine loop nest. It has one region containing its body. This region must contain one block that terminates with - [`affine.terminator`](#affineterminator-affineterminatorop). *Note:* when + [`affine.yield`](#affineyield-affineyieldop). *Note:* when `affine.for` is printed in custom format, the terminator is omitted. 
The block has one argument of [`index`](../LangRef.md#index-type) type that represents the induction variable of the loop. @@ -303,7 +304,7 @@ clauses. The latter may be empty (i.e. contain no blocks), meaning the absence of the else clause. When non-empty, both regions must contain exactly one block terminating with - [`affine.terminator`](#affineterminator-affineterminatorop). *Note:* when `affine.if` + [`affine.yield`](#affineyield-affineyieldop). *Note:* when `affine.if` is printed in custom format, the terminator is omitted. These blocks must not have any arguments. @@ -328,13 +329,17 @@ ``` }]; let arguments = (ins Variadic); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); let skipDefaultBuilders = 1; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " - "IntegerSet set, ValueRange args, bool withElseRegion"> + "IntegerSet set, ValueRange args, bool withElseRegion">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "TypeRange resultTypes, IntegerSet set, ValueRange args," + "bool withElseRegion">, ]; let extraClassDeclaration = [{ @@ -511,7 +516,9 @@ }]; } -def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> { +def AffineParallelOp : Affine_Op<"parallel", + [ImplicitAffineTerminator, RecursiveSideEffects, + DeclareOpInterfaceMethods]> { let summary = "multi-index parallel band operation"; let description = [{ The "affine.parallel" operation represents a hyper-rectangular affine @@ -522,17 +529,26 @@ steps, are positive constant integers which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. The body region must - contain exactly one block that terminates with "affine.terminator". + contain exactly one block that terminates with "affine.yield". 
The lower and upper bounds of a parallel operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The same restrictions hold for these SSA values as for all bindings of SSA values to dimensions and symbols. + Each value yielded by affine.yield will be accumulated by one of the + aggregation ops defined in the AtomicRMWKind enum. The order of + accumulation is unspecified, and lowering may produce any valid ordering. + Loops with a 0 trip count will produce as a result the identity value + associated with each accumulation (e.g. 0.0 for addf, 1.0 for mulf). + 'assign' accumulation for loops with a trip count != 1 produces undefined + results. + Note: Calling AffineParallelOp::build will create the required region and - block, and insert the required terminator. Parsing will also create the - required region, block, and terminator, even when they are missing from the - textual representation. + block, and insert the required terminator if it is trivial (i.e. no + values are yielded). Parsing will also create the required region, + block, and terminator, even when they are missing from the textual + representation. 
Example: @@ -544,19 +560,26 @@ }]; let arguments = (ins + TypedArrayAttrBase:$aggOps, AffineMapAttr:$lowerBoundsMap, AffineMapAttr:$upperBoundsMap, I64ArrayAttr:$steps, Variadic:$mapOperands); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result," + OpBuilder<"OpBuilder &builder, OperationState& result, " + "ArrayRef resultTypes, ArrayRef aggOps, " "ArrayRef ranges">, - OpBuilder<"OpBuilder &builder, OperationState &result, AffineMap lbMap," - "ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs">, - OpBuilder<"OpBuilder &builder, OperationState &result, AffineMap lbMap," - "ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs," + OpBuilder<"OpBuilder &builder, OperationState& result, " + "ArrayRef resultTypes, ArrayRef aggOps, " + "AffineMap lbMap, ValueRange lbArgs, " + "AffineMap ubMap, ValueRange ubArgs">, + OpBuilder<"OpBuilder &builder, OperationState& result, " + "ArrayRef resultTypes, ArrayRef aggOps, " + "AffineMap lbMap, ValueRange lbArgs, " + "AffineMap ubMap, ValueRange ubArgs, " "ArrayRef steps"> ]; @@ -581,10 +604,13 @@ } void setSteps(ArrayRef newSteps); + static StringRef getAggOpsAttrName() { return "aggOps"; } static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; } static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; } static StringRef getStepsAttrName() { return "steps"; } }]; + + let hasCanonicalizer = 1; } def AffinePrefetchOp : Affine_Op<"prefetch"> { @@ -733,35 +759,28 @@ let hasFolder = 1; } -def AffineTerminatorOp : - Affine_Op<"terminator", [NoSideEffect, Terminator]> { - let summary = "affine terminator operation"; +def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator]> { + let summary = "Yield values to parent operation"; let description = [{ - Syntax: - - ``` - operation ::= `"affine.terminator"() : () -> ()` + "affine.yield" yields zero or more SSA values from an affine op 
region and + terminates the region. The semantics of how the values yielded are used + is defined by the parent operation. + If "affine.yield" has any operands, the operands must match the parent + operation's results. + If the parent operation defines no values, then the "affine.yield" may be + left out in the custom syntax and the builders will insert one implicitly. + Otherwise, it has to be present in the syntax to indicate which values are + yielded. ``` - - Affine terminator is a special terminator operation for blocks inside affine - loops ([`affine.for`](#affinefor-affineforop)) and branches - ([`affine.if`](#affineif-affineifop)). It unconditionally transmits the - control flow to the successor of the operation enclosing the region. - - *Rationale*: bodies of affine operations are [blocks](../LangRef.md#blocks) - that must have terminators. Loops and branches represent structured control - flow and should not accept arbitrary branches as terminators. - - This operation does _not_ have a custom syntax. However, affine control - operations omit the terminator in their custom syntax for brevity. }]; - // No custom parsing/printing form. - let parser = ?; - let printer = ?; + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; - // Fully specified by traits. 
- let verifier = ?; + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -13,6 +13,7 @@ #ifndef STANDARD_OPS #define STANDARD_OPS +include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" @@ -455,31 +456,6 @@ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; } -//===----------------------------------------------------------------------===// -// AtomicRMWOp -//===----------------------------------------------------------------------===// - -def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; -def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; -def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; -def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; -def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; -def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; -def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; -def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; -def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; -def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; -def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; - -def AtomicRMWKindAttr : I64EnumAttr< - "AtomicRMWKind", "", - [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, - ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, - ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, - ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> { - let cppNamespace = "::mlir"; -} - def AtomicRMWOp : Std_Op<"atomic_rmw", [ AllTypesMatch<["value", "result"]>, TypesMatchWith<"value type matches element type 
of memref", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td @@ -0,0 +1,39 @@ +//===- StandardOpsBase.td - Standard operation definitions -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines base support for standard operations. +// +//===----------------------------------------------------------------------===// + +#ifndef STANDARD_OPS_BASE +#define STANDARD_OPS_BASE + +include "mlir/IR/OpBase.td" + +def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; +def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; +def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; +def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; +def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; +def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; +def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; +def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; +def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; +def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; +def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; + +def AtomicRMWKindAttr : I64EnumAttr< + "AtomicRMWKind", "", + [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, + ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, + ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, + ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> { + let cppNamespace = "::mlir"; +} + +#endif // STANDARD_OPS_BASE diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ 
b/mlir/lib/Analysis/Utils.cpp @@ -1020,7 +1020,7 @@ if (isa(opInst) || isa(opInst)) loadAndStoreOpInsts.push_back(opInst); - else if (!isa(opInst) && !isa(opInst) && + else if (!isa(opInst) && !isa(opInst) && !isa(opInst) && !MemoryEffectOpInterface::hasNoEffect(opInst)) return WalkResult::interrupt(); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -328,11 +328,11 @@ }; /// Affine terminators are removed. -class AffineTerminatorLowering : public OpRewritePattern { +class AffineTerminatorLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AffineTerminatorOp op, + LogicalResult matchAndRewrite(AffineYieldOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); return success(); diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -54,7 +54,7 @@ OpBuilder::InsertionGuard guard(nestedBuilder); bodyBuilderFn(iv); } - nestedBuilder.create(nestedLoc); + nestedBuilder.create(nestedLoc); }); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1528,7 +1528,7 @@ LogicalResult matchAndRewrite(AffineForOp forOp, PatternRewriter &rewriter) const override { - // Check that the body only contains a terminator. + // Check that the body only contains a yield. 
if (!llvm::hasSingleElement(*forOp.getBody())) return failure(); rewriter.eraseOp(forOp); @@ -1719,7 +1719,7 @@ OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } - nestedBuilder.create(nestedLoc); + nestedBuilder.create(nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch @@ -1772,14 +1772,14 @@ //===----------------------------------------------------------------------===// namespace { -/// Remove else blocks that have nothing other than the terminator. +/// Remove else blocks that have nothing other than a zero value yield. struct SimplifyDeadElse : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineIfOp ifOp, PatternRewriter &rewriter) const override { if (ifOp.elseRegion().empty() || - !llvm::hasSingleElement(*ifOp.getElseBlock())) + !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) return failure(); rewriter.startRootUpdate(ifOp); @@ -1841,6 +1841,9 @@ parser.getNameLoc(), "symbol operand count and integer set symbol count must match"); + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + // Create the regions for 'then' and 'else'. The latter must be created even // if it remains empty for the validity of the operation. result.regions.reserve(2); @@ -1874,9 +1877,10 @@ p << "affine.if " << conditionAttr; printDimAndSymbolList(op.operand_begin(), op.operand_end(), conditionAttr.getValue().getNumDims(), p); + p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); // Print the 'else' regions if it has any blocks. 
auto &elseRegion = op.elseRegion(); @@ -1884,7 +1888,7 @@ p << " else"; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); } // Print the attribute list. @@ -1905,14 +1909,31 @@ } void AffineIfOp::build(OpBuilder &builder, OperationState &result, - IntegerSet set, ValueRange args, bool withElseRegion) { + TypeRange resultTypes, IntegerSet set, ValueRange args, + bool withElseRegion) { + assert(resultTypes.empty() || withElseRegion); + for (Type type : resultTypes) { + result.addTypes(type); + } result.addOperands(args); result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); + Region *thenRegion = result.addRegion(); + thenRegion->push_back(new Block()); + if (resultTypes.empty()) + AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); + Region *elseRegion = result.addRegion(); - AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); - if (withElseRegion) - AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); + if (withElseRegion) { + elseRegion->push_back(new Block()); + if (resultTypes.empty()) + AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); + } +} + +void AffineIfOp::build(OpBuilder &builder, OperationState &result, + IntegerSet set, ValueRange args, bool withElseRegion) { + AffineIfOp::build(builder, result, {}, set, args, withElseRegion); } /// Canonicalize an affine if op's conditional (integer set + operands). 
@@ -2370,6 +2391,8 @@ //===----------------------------------------------------------------------===// void AffineParallelOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultTypes, + ArrayRef aggOps, ArrayRef ranges) { SmallVector lbExprs(ranges.size(), builder.getAffineConstantExpr(0)); @@ -2378,30 +2401,45 @@ for (int64_t range : ranges) ubExprs.push_back(builder.getAffineConstantExpr(range)); auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); - build(builder, result, lbMap, {}, ubMap, {}); + build(builder, result, resultTypes, aggOps, lbMap, {}, ubMap, {}); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, - AffineMap lbMap, ValueRange lbArgs, - AffineMap ubMap, ValueRange ubArgs) { + ArrayRef resultTypes, + ArrayRef aggOps, AffineMap lbMap, + ValueRange lbArgs, AffineMap ubMap, + ValueRange ubArgs) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of both maps are the same. assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); // Make default step sizes of 1. SmallVector steps(numDims, 1); - build(builder, result, lbMap, lbArgs, ubMap, ubArgs, steps); + build(builder, result, resultTypes, aggOps, lbMap, lbArgs, ubMap, ubArgs, + steps); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, - AffineMap lbMap, ValueRange lbArgs, - AffineMap ubMap, ValueRange ubArgs, - ArrayRef steps) { + ArrayRef resultTypes, + ArrayRef aggOps, AffineMap lbMap, + ValueRange lbArgs, AffineMap ubMap, + ValueRange ubArgs, ArrayRef steps) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of the maps matches the number of steps. 
assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); assert(numDims == steps.size() && "num dims and num steps mismatch"); + // Add the types + for (Type type : resultTypes) { + result.addTypes(type); + } + // Convert the aggOps to integer attributes + SmallVector aggOpAttrs; + for (auto agg : aggOps) { + aggOpAttrs.push_back(builder.getI64IntegerAttr(static_cast(agg))); + } + // Add the attributes + result.addAttribute(getAggOpsAttrName(), builder.getArrayAttr(aggOpAttrs)); result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); @@ -2414,7 +2452,20 @@ for (unsigned i = 0; i < numDims; ++i) body->addArgument(IndexType::get(builder.getContext())); bodyRegion->push_back(body); - ensureTerminator(*bodyRegion, builder, result.location); + if (resultTypes.empty()) + ensureTerminator(*bodyRegion, builder, result.location); +} + +Region &AffineParallelOp::getLoopBody() { return region(); } + +bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { + return !region().isAncestor(value.getParentRegion()); +} + +LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef ops) { + for (auto *op : ops) + op->moveBefore(*this); + return success(); } unsigned AffineParallelOp::getNumDims() { return steps().size(); } @@ -2477,6 +2528,17 @@ return op.emitOpError("region argument count and num results of upper " "bounds, lower bounds, and steps must all match"); } + if (op.aggOps().size() != op.getNumResults()) { + return op.emitOpError("aggregation must be specified for each output"); + } + // Verify agg ops are all valid + for (auto attr : op.aggOps()) { + auto intAttr = attr.dyn_cast(); + if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt())) { + return op.emitOpError("invalid aggregation attribute"); + } + } + // Verify that the bound operands are valid dimension/symbols. 
/// Lower bounds. if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), @@ -2489,6 +2551,92 @@ return success(); } +namespace { +/// This pattern removes affine.parallel ops with no induction variables. +struct AffineParallelRank0LoopRemover + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp op, + PatternRewriter &rewriter) const override { + // Check that there are no induction variables + if (op.lowerBoundsMap().getNumResults()) + return failure(); + // Remove the affine.parallel wrapper, retain the body in the same location + auto &parentOps = rewriter.getInsertionBlock()->getOperations(); + auto &parallelBodyOps = op.region().front().getOperations(); + auto yield = mlir::cast(std::prev(parallelBodyOps.end())); + for (auto it : zip(op.getResults(), yield.operands())) { + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + } + parentOps.splice(mlir::Block::iterator(op), parallelBodyOps, + parallelBodyOps.begin(), std::prev(parallelBodyOps.end())); + rewriter.eraseOp(op); + return success(); + } +}; + +/// This pattern removes induction variables whose range has a trip count of one. +struct AffineParallelTripCount1IndexRemover + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp op, + PatternRewriter &rewriter) const override { + auto ranges = op.getRangesValueMap(); + auto origNumArgs = op.getBody()->getArguments().size(); + size_t curArgNum = 0; + SmallVector newLowerBounds; + SmallVector newUpperBounds; + SmallVector newSteps; + for (unsigned i = 0; i < origNumArgs; i++) { + // Is the range a constant equal to the step (i.e. a trip count of 1)? 
+ auto constExpr = ranges.getResult(i).dyn_cast(); + int64_t step = op.steps()[i].template cast().getInt(); + if (constExpr && constExpr.getValue() == step) { + // Remove argument and replace with lower bound + auto curArg = op.getBody()->getArgument(curArgNum); + auto lowerBoundValue = rewriter.create( + op.getLoc(), op.lowerBoundsMap().getSubMap({i}), + op.getLowerBoundsOperands()); + curArg.replaceAllUsesWith(lowerBoundValue); + op.getBody()->eraseArgument(curArgNum); + } else { + // Keep argument + newLowerBounds.push_back(op.lowerBoundsMap().getResult(i)); + newUpperBounds.push_back(op.upperBoundsMap().getResult(i)); + newSteps.push_back(step); + curArgNum++; + } + } + // If no arguments were removed, return failure to match + if (newLowerBounds.size() == op.lowerBoundsMap().getNumResults()) + return failure(); + // Update attributes and return success + auto newLower = AffineMap::get(op.lowerBoundsMap().getNumDims(), + op.lowerBoundsMap().getNumSymbols(), + newLowerBounds, op.getContext()); + auto newUpper = AffineMap::get(op.upperBoundsMap().getNumDims(), + op.upperBoundsMap().getNumSymbols(), + newUpperBounds, op.getContext()); + op.setAttr(AffineParallelOp::getLowerBoundsMapAttrName(), + AffineMapAttr::get(newLower)); + op.setAttr(AffineParallelOp::getUpperBoundsMapAttrName(), + AffineMapAttr::get(newUpper)); + op.setAttr(AffineParallelOp::getStepsAttrName(), + rewriter.getI64ArrayAttr(newSteps)); + return success(); + } +}; + +} // end anonymous namespace + +void AffineParallelOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + static void print(OpAsmPrinter &p, AffineParallelOp op) { p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(), @@ -2509,11 +2657,22 @@ llvm::interleaveComma(steps, p); p << ')'; } + if (op.getNumResults()) { + p << " agg ("; + llvm::interleaveComma(op.aggOps(), p, [&](auto &attr) { + 
auto sym = + *symbolizeAtomicRMWKind(attr.template cast().getInt()); + p << "\"" << stringifyAtomicRMWKind(sym) << "\""; + }); + p << ") -> (" << op.getResultTypes() << ")"; + } + p.printRegion(op.region(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); p.printOptionalAttrDict( op.getAttrs(), - /*elidedAttrs=*/{AffineParallelOp::getLowerBoundsMapAttrName(), + /*elidedAttrs=*/{AffineParallelOp::getAggOpsAttrName(), + AffineParallelOp::getLowerBoundsMapAttrName(), AffineParallelOp::getUpperBoundsMapAttrName(), AffineParallelOp::getStepsAttrName()}); } @@ -2576,6 +2735,30 @@ result.addAttribute(AffineParallelOp::getStepsAttrName(), builder.getI64ArrayAttr(steps)); } + SmallVector aggOps; + if (succeeded(parser.parseOptionalKeyword("agg"))) { + if (parser.parseLParen()) + return failure(); + do { + StringAttr attrVal; + NamedAttrList attrStorage; + auto loc = parser.getCurrentLocation(); + if (parser.parseAttribute(attrVal, builder.getNoneType(), "agg", + attrStorage)) + return failure(); + auto attrOptional = symbolizeAtomicRMWKind(attrVal.getValue()); + if (!attrOptional) + return parser.emitError(loc, "invalid aggOp value: ") << attrVal; + aggOps.push_back(builder.getI64IntegerAttr( + static_cast(attrOptional.getValue()))); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRParen()) + return failure(); + } + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + result.addAttribute(AffineParallelOp::getAggOpsAttrName(), + builder.getArrayAttr(aggOps)); // Now parse the body. 
Region *body = result.addRegion(); @@ -2589,6 +2772,32 @@ return success(); } +//===----------------------------------------------------------------------===// +// AffineYieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AffineYieldOp op) { + auto parentOp = op.getParentOp(); + auto results = parentOp->getResults(); + auto operands = op.getOperands(); + + if (!isa(parentOp) && !isa(parentOp) && + !isa(parentOp)) { + return op.emitOpError() + << "affine.yield only terminates If, For or Parallel regions"; + } + if (parentOp->getNumResults() != op.getNumOperands()) + return op.emitOpError() << "parent of yield must have same number of " + "results as the yield operands"; + for (auto e : llvm::zip(results, operands)) { + if (std::get<0>(e).getType() != std::get<1>(e).getType()) + return op.emitOpError() + << "types mismatch between yield op and its parent"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // AffineVectorLoadOp //===----------------------------------------------------------------------===// @@ -2633,7 +2842,6 @@ if (memrefType.getElementType() != vectorType.getElementType()) return op->emitOpError( "requires memref and vector types of the same elemental type"); - return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -186,7 +186,7 @@ if (curBegin != block->end()) { // Can't be a terminator because it would have been skipped above. assert(!curBegin->isKnownTerminator() && "can't be a terminator"); - // Exclude the affine terminator - hence, the std::prev. + // Exclude the affine.yield - hence, the std::prev. 
affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/std::prev(block->end()), copyOptions, /*filterMemRef=*/llvm::None, copyNests); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -114,7 +114,7 @@ // Insert this op in the defined ops list. definedOps.insert(&op); - if (op.getNumOperands() == 0 && !isa(op)) { + if (op.getNumOperands() == 0 && !isa(op)) { LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); return false; } @@ -199,7 +199,7 @@ for (auto &op : *loopBody) { // We don't hoist for loops. if (!isa(op)) { - if (!isa(op)) { + if (!isa(op)) { if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { opsToMove.push_back(&op); } diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -604,3 +604,66 @@ } return } + +// ----- +// CHECK-LABEL: func @affine_parallel_rank0 +func @affine_parallel_rank0() { + // CHECK-NEXT: constant + %cst = constant 1.0 : f32 + // CHECK-NEXT: alloc + %0 = alloc() : memref + // CHECK-NEXT: affine.store + affine.parallel () = () to () { + affine.store %cst, %0[] : memref + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @affine_parallel_range1 +func @affine_parallel_range1() { + // CHECK-NEXT: constant + %cst = constant 1.0 : f32 + // CHECK-NEXT: alloc + %0 = alloc() : memref<2x4xf32> + // CHECK-NEXT: affine.store + affine.parallel (%i, %j) = (0, 1) to (1, 2) { + affine.store %cst, %0[%i, %j] : memref<2x4xf32> + } + // CHECK-NEXT: return + return +} + +// ----- +// CHECK-LABEL: func @affine_parallel_range_step +func @affine_parallel_range_step() { + // CHECK-NEXT: constant + %cst = constant 1.0 : 
f32 + // CHECK-NEXT: alloc + %0 = alloc() : memref<100xf32> + // CHECK-NEXT: affine.store + affine.parallel (%i) = (0) to (100) step (100) { + affine.store %cst, %0[%i] : memref<100xf32> + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func @affine_parallel_partial_range1 +func @affine_parallel_partial_range1() { + // CHECK-NEXT: constant + %cst = constant 1.0 : f32 + // CHECK-NEXT: alloc + %0 = alloc() : memref<2x4xf32> + // CHECK-NEXT: affine.parallel (%{{.*}}) = (0) to (10) + affine.parallel (%i, %j) = (0, 1) to (10, 2) { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 1] + affine.store %cst, %0[%i, %j] : memref<2x4xf32> + } + // CHECK: return + return +} diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -266,6 +266,47 @@ // ----- +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xf32> + // expected-error@+1 {{op aggregation must be specified for each output}} + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xf32> + affine.yield %2 : f32 + } + return +} + +// ----- + +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xf32> + // expected-error@+1 {{invalid aggOp value: "bad"}} + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) agg ("bad") -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xf32> + affine.yield %2 : f32 + } + return +} + +// ----- +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xi32> + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) agg ("minf") -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xi32> + // 
expected-error@+1 {{types mismatch between yield op and its parent}} + affine.yield %2 : i32 + } + return +} + +// ----- + func @vector_load_invalid_vector_type() { %0 = alloc() : memref<100xf32> affine.for %i0 = 0 to 16 step 8 { diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -2,7 +2,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s // Check that the attributes for the affine operations are round-tripped. -// Check that `affine.terminator` is visible in the generic form. +// Check that `affine.yield` is visible in the generic form. // CHECK-LABEL: @empty func @empty() { // CHECK: affine.for @@ -10,7 +10,7 @@ // // GENERIC: "affine.for"() // GENERIC-NEXT: ^bb0(%{{.*}}: index): - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) affine.for %i = 0 to 10 { } {some_attr = true} @@ -19,7 +19,7 @@ // CHECK-NEXT: } {some_attr = true} // // GENERIC: "affine.if"() - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }, { // GENERIC-NEXT: }) affine.if affine_set<() : ()> () { @@ -29,10 +29,10 @@ // CHECK: } {some_attr = true} // // GENERIC: "affine.if"() - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }, { // GENERIC-NEXT: "foo"() : () -> () - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) affine.if affine_set<() : ()> () { } else { @@ -42,19 +42,19 @@ return } -// Check that an explicit affine terminator is not printed in custom format. +// Check that an explicit affine.yield is not printed in custom format. // Check that no extra terminator is introduced. 
-// CHECK-LABEL: @affine_terminator -func @affine_terminator() { +// CHECK-LABEL: @affine.yield +func @affine.yield() { // CHECK: affine.for // CHECK-NEXT: } // // GENERIC: "affine.for"() ( { // GENERIC-NEXT: ^bb0(%{{.*}}: index): // no predecessors - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) {lower_bound = #map0, step = 1 : index, upper_bound = #map1} : () -> () affine.for %i = 0 to 10 { - "affine.terminator"() : () -> () + "affine.yield"() : () -> () } return } @@ -153,14 +153,34 @@ // ----- -// CHECK-LABEL: @parallel -// CHECK-SAME: (%[[N:.*]]: index) -func @parallel(%N : index) { +// CHECK-LABEL: func @parallel +// CHECK-SAME: (%[[A:.*]]: memref<100x100xf32>, %[[N:.*]]: index) +func @parallel(%A : memref<100x100xf32>, %N : index) { // CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10) affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) { - // CHECK-NEXT: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) - affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) { + // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) agg ("minf", "maxf") -> (f32, f32) + %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) agg ("minf", "maxf") -> (f32, f32) { + %2 = affine.load %A[%i0 + %i0, %j0 + %j1] : memref<100x100xf32> + affine.yield %2, %2 : f32, f32 } } return } + +// ----- + +// CHECK-LABEL: func @affine_if +func @affine_if() -> f32 { + // CHECK: %[[ZERO:.*]] = constant {{.*}} : f32 + %zero = constant 0.0 : f32 + // CHECK: %[[OUT:.*]] = affine.if {{.*}}() -> f32 { + %0 = affine.if affine_set<() : ()> () -> f32 { + // CHECK: affine.yield %[[ZERO]] : f32 + affine.yield %zero : f32 + } else { + // CHECK: affine.yield %[[ZERO]] : f32 + affine.yield %zero : f32 + } + // CHECK: return %[[OUT]] : f32 + return %0 : f32 +} diff --git 
a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -474,7 +474,7 @@ func @return_inside_loop() { affine.for %i = 1 to 100 { - // expected-error@-1 {{op expects regions to end with 'affine.terminator', found 'std.return'}} + // expected-error@-1 {{op expects regions to end with 'affine.yield', found 'std.return'}} // expected-note@-2 {{in custom textual format, the absence of terminator implies}} return } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -840,12 +840,12 @@ "affine.if"(%c, %N, %c) ({ // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - "affine.terminator"() : () -> () + "affine.yield"() : () -> () // CHECK-NEXT: } else { }, { // The else region. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index - "affine.terminator"() : () -> () + "affine.yield"() : () -> () }) { condition = #set0 } : (index, index, index) -> () return