diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -28,7 +29,7 @@ class AffineBound; class AffineDimExpr; class AffineValueMap; -class AffineTerminatorOp; +class AffineYieldOp; class FlatAffineConstraints; class OpBuilder; diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -14,7 +14,9 @@ #define AFFINE_OPS include "mlir/Dialect/Affine/IR/AffineOpsBase.td" +include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -38,9 +40,9 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -// Require regions to have affine terminator. +// Require regions to have affine.yield. def ImplicitAffineTerminator - : SingleBlockImplicitTerminator<"AffineTerminatorOp">; + : SingleBlockImplicitTerminator<"AffineYieldOp">; def AffineApplyOp : Affine_Op<"apply", [NoSideEffect]> { let summary = "affine apply operation"; @@ -125,7 +127,7 @@ The `affine.for` operation represents an affine loop nest. It has one region containing its body. This region must contain one block that terminates with - [`affine.terminator`](#affineterminator-affineterminatorop). *Note:* when + [`affine.yield`](#affineyield-affineyieldop). 
*Note:* when `affine.for` is printed in custom format, the terminator is omitted. The block has one argument of [`index`](../LangRef.md#index-type) type that represents the induction variable of the loop. @@ -301,12 +303,13 @@ symbols. The `affine.if` operation contains two regions for the "then" and "else" - clauses. The latter may be empty (i.e. contain no blocks), meaning the - absence of the else clause. When non-empty, both regions must contain - exactly one block terminating with - [`affine.terminator`](#affineterminator-affineterminatorop). *Note:* when `affine.if` - is printed in custom format, the terminator is omitted. These blocks must - not have any arguments. + clauses. `affine.if` may return results that are defined in its regions. + The values defined are determined by which execution path is taken. Each + non-empty region of the `affine.if` must contain a single block with no arguments, + and be terminated by `affine.yield`. If `affine.if` defines no values, + the `affine.yield` can be left out, and will be inserted implicitly. + Otherwise, it must be explicit. If no values are defined, the else region + may be empty (i.e. contain no blocks). 
Example: @@ -327,15 +330,39 @@ return } ``` + + Example with an explicit yield (initialization with edge padding): + + ```mlir + #interior = affine_set<(i, j) : (i - 1 >= 0, j - 1 >= 0, 10 - i >= 0, 10 - j >= 0)> + func @pad_edges(%I : memref<10x10xf32>) -> (memref<12x12xf32>) { + %O = alloc() : memref<12x12xf32> + affine.parallel (%i, %j) = (0, 0) to (12, 12) { + %1 = affine.if #interior (%i, %j) -> f32 { + %2 = affine.load %I[%i - 1, %j - 1] : memref<10x10xf32> + affine.yield %2 : f32 + } else { + %2 = constant 0.0 : f32 + affine.yield %2 : f32 + } + affine.store %1, %O[%i, %j] : memref<12x12xf32> + } + return %O : memref<12x12xf32> + } + ``` }]; let arguments = (ins Variadic); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); let skipDefaultBuilders = 1; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " - "IntegerSet set, ValueRange args, bool withElseRegion"> + "IntegerSet set, ValueRange args, bool withElseRegion">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "TypeRange resultTypes, IntegerSet set, ValueRange args, " + "bool withElseRegion">, ]; let extraClassDeclaration = [{ @@ -512,7 +539,9 @@ }]; } -def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> { +def AffineParallelOp : Affine_Op<"parallel", + [ImplicitAffineTerminator, RecursiveSideEffects, + DeclareOpInterfaceMethods]> { let summary = "multi-index parallel band operation"; let description = [{ The "affine.parallel" operation represents a hyper-rectangular affine @@ -523,41 +552,68 @@ steps, are positive constant integers which defaults to "1" if not present. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. The body region must - contain exactly one block that terminates with "affine.terminator". 
+ contain exactly one block that terminates with "affine.yield". 
The lower and upper bounds of a parallel operation are represented as an application of an affine mapping to a list of SSA values passed to the map. The same restrictions hold for these SSA values as for all bindings of SSA values to dimensions and symbols. + Each value yielded by affine.yield will be accumulated/reduced via one of + the reduction methods defined in the AtomicRMWKind enum. The order of + reduction is unspecified, and lowering may produce any valid ordering. + Loops with a 0 trip count will produce as a result the identity value + associated with each reduction (e.g. 0.0 for addf, 1.0 for mulf). Assign + reductions for loops with a trip count != 1 produce undefined results. + Note: Calling AffineParallelOp::build will create the required region and - block, and insert the required terminator. Parsing will also create the - required region, block, and terminator, even when they are missing from the - textual representation. + block, and insert the required terminator if it is trivial (i.e. no values + are yielded). Parsing will also create the required region, block, and + terminator, even when they are missing from the textual representation. - Example: + Example (3x3 valid convolution): ```mlir - affine.parallel (%i, %j) = (0, 0) to (10, 10) step (1, 1) { - ... 
+ func @conv_2d(%D : memref<100x100xf32>, %K : memref<3x3xf32>) -> (memref<98x98xf32>) { + %O = alloc() : memref<98x98xf32> + affine.parallel (%x, %y) = (0, 0) to (98, 98) { + %0 = affine.parallel (%kx, %ky) = (0, 0) to (3, 3) reduce ("addf") -> (f32) { + %1 = affine.load %D[%x + %kx, %y + %ky] : memref<100x100xf32> + %2 = affine.load %K[%kx, %ky] : memref<3x3xf32> + %3 = mulf %1, %2 : f32 + affine.yield %3 : f32 + } + affine.store %0, %O[%x, %y] : memref<98x98xf32> + } + return %O : memref<98x98xf32> } ``` }]; let arguments = (ins + TypedArrayAttrBase:$reductions, AffineMapAttr:$lowerBoundsMap, AffineMapAttr:$upperBoundsMap, I64ArrayAttr:$steps, Variadic:$mapOperands); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result," + OpBuilder<"OpBuilder &builder, OperationState &result, " + "ArrayRef resultTypes, " + "ArrayRef reductions, " "ArrayRef ranges">, - OpBuilder<"OpBuilder &builder, OperationState &result, AffineMap lbMap," - "ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs">, - OpBuilder<"OpBuilder &builder, OperationState &result, AffineMap lbMap," - "ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs," + OpBuilder<"OpBuilder &builder, OperationState &result, " + "ArrayRef resultTypes, " + "ArrayRef reductions, " + "AffineMap lbMap, ValueRange lbArgs, " + "AffineMap ubMap, ValueRange ubArgs">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "ArrayRef resultTypes, " + "ArrayRef reductions, " + "AffineMap lbMap, ValueRange lbArgs, " + "AffineMap ubMap, ValueRange ubArgs, " "ArrayRef steps"> ]; @@ -582,6 +638,7 @@ } void setSteps(ArrayRef newSteps); + static StringRef getReductionsAttrName() { return "reductions"; } static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; } static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; } static StringRef getStepsAttrName() { return "steps"; } @@ -734,35 +791,27 @@ let hasFolder = 1; } -def 
AffineTerminatorOp : - Affine_Op<"terminator", [NoSideEffect, Terminator]> { - let summary = "affine terminator operation"; +def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike]> { + let summary = "Yield values to parent operation"; let description = [{ - Syntax: - + "affine.yield" yields zero or more SSA values from an affine op region and + terminates the region. The semantics of how the values yielded are used + is defined by the parent operation. + If "affine.yield" has any operands, the operands must match the parent + operation's results. + If the parent operation defines no values, then the "affine.yield" may be + left out in the custom syntax and the builders will insert one implicitly. + Otherwise, it has to be present in the syntax to indicate which values are + yielded. - ``` - operation ::= `"affine.terminator"() : () -> ()` - ``` - - Affine terminator is a special terminator operation for blocks inside affine - loops ([`affine.for`](#affinefor-affineforop)) and branches - ([`affine.if`](#affineif-affineifop)). It unconditionally transmits the - control flow to the successor of the operation enclosing the region. - - *Rationale*: bodies of affine operations are [blocks](../LangRef.md#blocks) - that must have terminators. Loops and branches represent structured control - flow and should not accept arbitrary branches as terminators. - - This operation does _not_ have a custom syntax. However, affine control - operations omit the terminator in their custom syntax for brevity. }]; - // No custom parsing/printing form. - let parser = ?; - let printer = ?; + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; - // Fully specified by traits. 
- let verifier = ?; + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -13,6 +13,7 @@ #ifndef STANDARD_OPS #define STANDARD_OPS +include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" @@ -462,31 +463,6 @@ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; } -//===----------------------------------------------------------------------===// -// AtomicRMWOp -//===----------------------------------------------------------------------===// - -def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; -def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; -def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; -def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; -def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; -def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; -def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; -def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; -def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; -def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; -def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; - -def AtomicRMWKindAttr : I64EnumAttr< - "AtomicRMWKind", "", - [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, - ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, - ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, - ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> { - let cppNamespace = "::mlir"; -} - def AtomicRMWOp : Std_Op<"atomic_rmw", [ AllTypesMatch<["value", "result"]>, TypesMatchWith<"value type matches element type 
of memref", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td @@ -0,0 +1,39 @@ +//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines base support for standard operations. +// +//===----------------------------------------------------------------------===// + +#ifndef STANDARD_OPS_BASE +#define STANDARD_OPS_BASE + +include "mlir/IR/OpBase.td" + +def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; +def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; +def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; +def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; +def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; +def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; +def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; +def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; +def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; +def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; +def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; + +def AtomicRMWKindAttr : I64EnumAttr< + "AtomicRMWKind", "", + [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, + ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, + ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, + ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> { + let cppNamespace = "::mlir"; +} + +#endif // STANDARD_OPS_BASE diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ 
b/mlir/lib/Analysis/Utils.cpp @@ -1018,7 +1018,7 @@ auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult { if (isa(opInst)) loadAndStoreOpInsts.push_back(opInst); - else if (!isa(opInst) && + else if (!isa(opInst) && !MemoryEffectOpInterface::hasNoEffect(opInst)) return WalkResult::interrupt(); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -327,12 +327,12 @@ } }; -/// Affine terminators are removed. -class AffineTerminatorLowering : public OpRewritePattern { +/// Affine yields ops are removed. +class AffineYieldOpLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AffineTerminatorOp op, + LogicalResult matchAndRewrite(AffineYieldOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); return success(); @@ -619,7 +619,7 @@ AffineStoreLowering, AffineForLowering, AffineIfLowering, - AffineTerminatorLowering>(ctx); + AffineYieldOpLowering>(ctx); // clang-format on } diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -54,7 +54,7 @@ OpBuilder::InsertionGuard guard(nestedBuilder); bodyBuilderFn(iv); } - nestedBuilder.create(nestedLoc); + nestedBuilder.create(nestedLoc); }); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1528,7 +1528,7 @@ LogicalResult matchAndRewrite(AffineForOp forOp, PatternRewriter &rewriter) const override { - // Check that the body only contains a terminator. 
+ // Check that the body only contains a yield. if (!llvm::hasSingleElement(*forOp.getBody())) return failure(); rewriter.eraseOp(forOp); @@ -1719,7 +1719,7 @@ OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } - nestedBuilder.create(nestedLoc); + nestedBuilder.create(nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch @@ -1772,14 +1772,14 @@ //===----------------------------------------------------------------------===// namespace { -/// Remove else blocks that have nothing other than the terminator. +/// Remove else blocks that have nothing other than a zero value yield. struct SimplifyDeadElse : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineIfOp ifOp, PatternRewriter &rewriter) const override { if (ifOp.elseRegion().empty() || - !llvm::hasSingleElement(*ifOp.getElseBlock())) + !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) return failure(); rewriter.startRootUpdate(ifOp); @@ -1834,6 +1834,9 @@ parser.getNameLoc(), "symbol operand count and integer set symbol count must match"); + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + // Create the regions for 'then' and 'else'. The latter must be created even // if it remains empty for the validity of the operation. result.regions.reserve(2); @@ -1867,9 +1870,10 @@ p << "affine.if " << conditionAttr; printDimAndSymbolList(op.operand_begin(), op.operand_end(), conditionAttr.getValue().getNumDims(), p); + p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); // Print the 'else' regions if it has any blocks. 
auto &elseRegion = op.elseRegion(); @@ -1877,7 +1881,7 @@ p << " else"; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); } // Print the attribute list. @@ -1898,14 +1902,30 @@ } void AffineIfOp::build(OpBuilder &builder, OperationState &result, - IntegerSet set, ValueRange args, bool withElseRegion) { + TypeRange resultTypes, IntegerSet set, ValueRange args, + bool withElseRegion) { + assert(resultTypes.empty() || withElseRegion); + result.addTypes(resultTypes); result.addOperands(args); result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); + Region *thenRegion = result.addRegion(); + thenRegion->push_back(new Block()); + if (resultTypes.empty()) + AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); + Region *elseRegion = result.addRegion(); - AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); - if (withElseRegion) - AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); + if (withElseRegion) { + elseRegion->push_back(new Block()); + if (resultTypes.empty()) + AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); + } +} + +void AffineIfOp::build(OpBuilder &builder, OperationState &result, + IntegerSet set, ValueRange args, bool withElseRegion) { + AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, + withElseRegion); } /// Canonicalize an affine if op's conditional (integer set + operands). 
@@ -2363,6 +2383,8 @@ //===----------------------------------------------------------------------===// void AffineParallelOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultTypes, + ArrayRef reductions, ArrayRef ranges) { SmallVector lbExprs(ranges.size(), builder.getAffineConstantExpr(0)); @@ -2371,10 +2393,13 @@ for (int64_t range : ranges) ubExprs.push_back(builder.getAffineConstantExpr(range)); auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); - build(builder, result, lbMap, {}, ubMap, {}); + build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap, + /*ubArgs=*/{}); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultTypes, + ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs) { auto numDims = lbMap.getNumResults(); @@ -2383,10 +2408,13 @@ "num dims and num results mismatch"); // Make default step sizes of 1. SmallVector steps(numDims, 1); - build(builder, result, lbMap, lbArgs, ubMap, ubArgs, steps); + build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs, + steps); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultTypes, + ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs, ArrayRef steps) { @@ -2395,6 +2423,15 @@ assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); assert(numDims == steps.size() && "num dims and num steps mismatch"); + + result.addTypes(resultTypes); + // Convert the reductions to integer attributes. 
+ SmallVector reductionAttrs; + for (AtomicRMWKind reduction : reductions) + reductionAttrs.push_back( + builder.getI64IntegerAttr(static_cast(reduction))); + result.addAttribute(getReductionsAttrName(), + builder.getArrayAttr(reductionAttrs)); result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); @@ -2407,7 +2444,20 @@ for (unsigned i = 0; i < numDims; ++i) body->addArgument(IndexType::get(builder.getContext())); bodyRegion->push_back(body); - ensureTerminator(*bodyRegion, builder, result.location); + if (resultTypes.empty()) + ensureTerminator(*bodyRegion, builder, result.location); +} + +Region &AffineParallelOp::getLoopBody() { return region(); } + +bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { + return !region().isAncestor(value.getParentRegion()); +} + +LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef ops) { + for (Operation *op : ops) + op->moveBefore(*this); + return success(); } unsigned AffineParallelOp::getNumDims() { return steps().size(); } @@ -2466,10 +2516,20 @@ if (op.lowerBoundsMap().getNumResults() != numDims || op.upperBoundsMap().getNumResults() != numDims || op.steps().size() != numDims || - op.getBody()->getNumArguments() != numDims) { + op.getBody()->getNumArguments() != numDims) return op.emitOpError("region argument count and num results of upper " "bounds, lower bounds, and steps must all match"); + + if (op.reductions().size() != op.getNumResults()) + return op.emitOpError("a reduction must be specified for each output"); + + // Verify reduction ops are all valid + for (Attribute attr : op.reductions()) { + auto intAttr = attr.dyn_cast(); + if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt())) + return op.emitOpError("invalid reduction attribute"); } + // Verify that the bound operands are valid dimension/symbols. /// Lower bounds. 
if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), @@ -2502,11 +2562,22 @@ llvm::interleaveComma(steps, p); p << ')'; } + if (op.getNumResults()) { + p << " reduce ("; + llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { + AtomicRMWKind sym = + *symbolizeAtomicRMWKind(attr.template cast().getInt()); + p << "\"" << stringifyAtomicRMWKind(sym) << "\""; + }); + p << ") -> (" << op.getResultTypes() << ")"; + } + p.printRegion(op.region(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/op.getNumResults()); p.printOptionalAttrDict( op.getAttrs(), - /*elidedAttrs=*/{AffineParallelOp::getLowerBoundsMapAttrName(), + /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), + AffineParallelOp::getLowerBoundsMapAttrName(), AffineParallelOp::getUpperBoundsMapAttrName(), AffineParallelOp::getStepsAttrName()}); } @@ -2570,6 +2641,40 @@ builder.getI64ArrayAttr(steps)); } + // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the + // quoted strings are members of the enum AtomicRMWKind. + SmallVector reductions; + if (succeeded(parser.parseOptionalKeyword("reduce"))) { + if (parser.parseLParen()) + return failure(); + do { + // Parse a single quoted string via the attribute parsing, and then + // verify it is a member of the enum and convert to its integer + // representation. + StringAttr attrVal; + NamedAttrList attrStorage; + auto loc = parser.getCurrentLocation(); + if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", + attrStorage)) + return failure(); + llvm::Optional reduction = + symbolizeAtomicRMWKind(attrVal.getValue()); + if (!reduction) + return parser.emitError(loc, "invalid reduction value: ") << attrVal; + reductions.push_back(builder.getI64IntegerAttr( + static_cast(reduction.getValue()))); + // While we keep getting commas, keep parsing.
+ } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRParen()) + return failure(); + } + result.addAttribute(AffineParallelOp::getReductionsAttrName(), + builder.getArrayAttr(reductions)); + + // Parse return types of reductions (if any) + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + // Now parse the body. Region *body = result.addRegion(); SmallVector types(ivs.size(), indexType); @@ -2582,6 +2687,30 @@ return success(); } +//===----------------------------------------------------------------------===// +// AffineYieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AffineYieldOp op) { + auto parentOp = op.getParentOp(); + auto results = parentOp->getResults(); + auto operands = op.getOperands(); + + if (!isa(parentOp)) + return op.emitOpError() + << "only terminates If, For or Parallel regions"; + if (parentOp->getNumResults() != op.getNumOperands()) + return op.emitOpError() << "parent of yield must have same number of " + "results as the yield operands"; + for (auto it : llvm::zip(results, operands)) { + if (std::get<0>(it).getType() != std::get<1>(it).getType()) + return op.emitOpError() + << "types mismatch between yield op and its parent"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// AffineVectorLoadOp +//===----------------------------------------------------------------------===// @@ -2626,7 +2755,6 @@ if (memrefType.getElementType() != vectorType.getElementType()) return op->emitOpError( "requires memref and vector types of the same elemental type"); - return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -186,7 
+186,7 @@ if (curBegin != block->end()) { // Can't be a terminator because it would have been skipped above. assert(!curBegin->isKnownTerminator() && "can't be a terminator"); - // Exclude the affine terminator - hence, the std::prev. + // Exclude the affine.yield - hence, the std::prev. affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/std::prev(block->end()), copyOptions, /*filterMemRef=*/llvm::None, copyNests); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -114,7 +114,7 @@ // Insert this op in the defined ops list. definedOps.insert(&op); - if (op.getNumOperands() == 0 && !isa(op)) { + if (op.getNumOperands() == 0 && !isa(op)) { LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); return false; } @@ -199,7 +199,7 @@ for (auto &op : *loopBody) { // We don't hoist for loops. if (!isa(op)) { - if (!isa(op)) { + if (!isa(op)) { if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { opsToMove.push_back(&op); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -131,6 +131,11 @@ // Returns success if any hoisting happened. LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { + // Bail out early if the ifOp returns a result. TODO: Consider how to + // properly support this case. + if (ifOp.getNumResults() != 0) + return failure(); + // Apply canonicalization patterns and folding - this is necessary for the // hoisting check to be correct (operands should be composed), and to be more // effective (no unused operands). 
Since the pattern rewriter's folding is diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -266,6 +266,47 @@ // ----- +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xf32> + // expected-error@+1 {{reduction must be specified for each output}} + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xf32> + affine.yield %2 : f32 + } + return +} + +// ----- + +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xf32> + // expected-error@+1 {{invalid reduction value: "bad"}} + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("bad") -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xf32> + affine.yield %2 : f32 + } + return +} + +// ----- +// CHECK-LABEL: @affine_parallel + +func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<100x100xi32> + %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minf") -> (f32) { + %2 = affine.load %0[%i, %j] : memref<100x100xi32> + // expected-error@+1 {{types mismatch between yield op and its parent}} + affine.yield %2 : i32 + } + return +} + +// ----- + func @vector_load_invalid_vector_type() { %0 = alloc() : memref<100xf32> affine.for %i0 = 0 to 16 step 8 { diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -2,7 +2,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s // Check that the attributes for the affine operations are round-tripped. 
-// Check that `affine.terminator` is visible in the generic form. +// Check that `affine.yield` is visible in the generic form. // CHECK-LABEL: @empty func @empty() { // CHECK: affine.for @@ -10,7 +10,7 @@ // // GENERIC: "affine.for"() // GENERIC-NEXT: ^bb0(%{{.*}}: index): - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) affine.for %i = 0 to 10 { } {some_attr = true} @@ -19,7 +19,7 @@ // CHECK-NEXT: } {some_attr = true} // // GENERIC: "affine.if"() - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }, { // GENERIC-NEXT: }) affine.if affine_set<() : ()> () { @@ -29,10 +29,10 @@ // CHECK: } {some_attr = true} // // GENERIC: "affine.if"() - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }, { // GENERIC-NEXT: "foo"() : () -> () - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) affine.if affine_set<() : ()> () { } else { @@ -42,19 +42,19 @@ return } -// Check that an explicit affine terminator is not printed in custom format. +// Check that an explicit affine.yield is not printed in custom format. // Check that no extra terminator is introduced. 
-// CHECK-LABEL: @affine_terminator -func @affine_terminator() { +// CHECK-LABEL: @affine_yield +func @affine_yield() { // CHECK: affine.for // CHECK-NEXT: } // // GENERIC: "affine.for"() ( { // GENERIC-NEXT: ^bb0(%{{.*}}: index): // no predecessors - // GENERIC-NEXT: "affine.terminator"() : () -> () + // GENERIC-NEXT: "affine.yield"() : () -> () // GENERIC-NEXT: }) {lower_bound = #map0, step = 1 : index, upper_bound = #map1} : () -> () affine.for %i = 0 to 10 { - "affine.terminator"() : () -> () + "affine.yield"() : () -> () } return } @@ -153,14 +153,34 @@ // ----- -// CHECK-LABEL: @parallel -// CHECK-SAME: (%[[N:.*]]: index) -func @parallel(%N : index) { +// CHECK-LABEL: func @parallel +// CHECK-SAME: (%[[A:.*]]: memref<100x100xf32>, %[[N:.*]]: index) +func @parallel(%A : memref<100x100xf32>, %N : index) { // CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10) affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) { - // CHECK-NEXT: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) - affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) { + // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minf", "maxf") -> (f32, f32) + %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minf", "maxf") -> (f32, f32) { + %2 = affine.load %A[%i0 + %i1, %j0 + %j1] : memref<100x100xf32> + affine.yield %2, %2 : f32, f32 } } return } + +// ----- + +// CHECK-LABEL: func @affine_if +func @affine_if() -> f32 { + // CHECK: %[[ZERO:.*]] = constant {{.*}} : f32 + %zero = constant 0.0 : f32 + // CHECK: %[[OUT:.*]] = affine.if {{.*}}() -> f32 { + %0 = affine.if affine_set<() : ()> () -> f32 { + // CHECK: affine.yield %[[ZERO]] : f32 + affine.yield %zero : f32 + } else { + // CHECK: affine.yield %[[ZERO]] : f32 + affine.yield %zero : f32 + } + // CHECK: return %[[OUT]] : f32 + return %0 : f32 +} diff 
--git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -474,7 +474,7 @@ func @return_inside_loop() { affine.for %i = 1 to 100 { - // expected-error@-1 {{op expects regions to end with 'affine.terminator', found 'std.return'}} + // expected-error@-1 {{op expects regions to end with 'affine.yield', found 'std.return'}} // expected-note@-2 {{in custom textual format, the absence of terminator implies}} return } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -842,12 +842,12 @@ "affine.if"(%c, %N, %c) ({ // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - "affine.terminator"() : () -> () + "affine.yield"() : () -> () // CHECK-NEXT: } else { }, { // The else region. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index - "affine.terminator"() : () -> () + "affine.yield"() : () -> () }) { condition = #set0 } : (index, index, index) -> () return