diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -174,30 +174,73 @@ return } ``` + `affine.for` can also operate on loop-carried variables and return the final + values after loop termination. The initial values of the variables are + passed as additional SSA operands to the `affine.for`, following the operands + of the lower and upper bound maps (the step is an attribute, not an operand). + The body block has one argument per variable, after the induction variable, + holding the value of the variable at the current iteration. + + The region must terminate with an `affine.yield` that passes the current + values of all the loop-carried variables to the next iteration or, after the + last iteration, to the `affine.for` results. + + The `affine.for` results hold the final values after the last iteration. + For example, to sum every other element among the first ten of a memref: + + ```mlir + func @reduce(%buffer: memref<1024xf32>) -> (f32) { + // Initial sum set to 0. + %sum_0 = constant 0.0 : f32 + // iter_args binds initial values to the loop's region arguments. + %sum = affine.for %i = 0 to 10 step 2 + iter_args(%sum_iter = %sum_0) -> (f32) { + %t = affine.load %buffer[%i] : memref<1024xf32> + %sum_next = addf %sum_iter, %t : f32 + // Yield current iteration sum to next iteration %sum_iter or to %sum + // if final iteration. + affine.yield %sum_next : f32 + } + return %sum : f32 + } + ``` + If the `affine.for` defines any values, a yield terminator must be + explicitly present. The number and types of the `affine.for` results must + match the initial values in the `iter_args` binding and the yield operands.
}]; let arguments = (ins Variadic); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let skipDefaultBuilders = 1; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " "int64_t lowerBound, int64_t upperBound, int64_t step = 1, " - "function_ref bodyBuilder " - " = nullptr">, + "ValueRange iterArgs = llvm::None, function_ref bodyBuilder = nullptr">, OpBuilder<"OpBuilder &builder, OperationState &result, " "ValueRange lbOperands, AffineMap lbMap, " "ValueRange ubOperands, AffineMap ubMap, " - "int64_t step = 1, " - "function_ref bodyBuilder " - " = nullptr"> + "int64_t step = 1, ValueRange iterArgs = llvm::None, " + "function_ref " + "bodyBuilder = nullptr"> ]; let extraClassDeclaration = [{ + using BodyBuilderFn = + function_ref; + static StringRef getStepAttrName() { return "step"; } static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getUpperBoundAttrName() { return "upper_bound"; } Value getInductionVar() { return getBody()->getArgument(0); } + Block::BlockArgListType getRegionIterArgs() { + return getBody()->getArguments().drop_front(); + } + Operation::operand_range getIterOperands() { + return getOperands().drop_front(getNumControlOperands()); + } // TODO: provide iterators for the lower and upper bound operands // if the current access via getLowerBound(), getUpperBound() is too slow. @@ -251,6 +294,16 @@ IntegerAttr::get(IndexType::get(context), step)); } + /// Returns the number of body arguments after the induction variable. + unsigned getNumRegionIterArgs() { + return getBody()->getNumArguments() - 1; + } + /// Returns the number of operands of the lower and upper bound maps. + int64_t getNumControlOperands() { return getOperation()->getNumOperands() - getNumIterOperands(); } + + /// Returns the number of iter_args operands (those after the bound operands). + int64_t getNumIterOperands(); + /// Returns true if the lower bound is constant. bool hasConstantLowerBound(); /// Returns true if the upper bound is constant.
@@ -540,7 +593,7 @@ }]; } -def AffineParallelOp : Affine_Op<"parallel", +def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator, RecursiveSideEffects, DeclareOpInterfaceMethods]> { let summary = "multi-index parallel band operation"; @@ -569,7 +622,7 @@ Note: Calling AffineParallelOp::build will create the required region and block, and insert the required terminator if it is trivial (i.e. no values - are yielded). Parsing will also create the required region, block, and + are yielded). Parsing will also create the required region, block, and terminator, even when they are missing from the textual representation. Example (3x3 valid convolution): diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -47,8 +47,9 @@ // updating the scoped context. builder.create( loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs, - builder.getMultiDimIdentityMap(ubs.size()), step, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv) { + builder.getMultiDimIdentityMap(ubs.size()), step, /*iterArgs=*/llvm::None, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange iterArgs) { if (bodyBuilderFn) { ScopedContext nestedContext(nestedBuilder, nestedLoc); OpBuilder::InsertionGuard guard(nestedBuilder); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1173,10 +1173,10 @@ // AffineForOp //===----------------------------------------------------------------------===// -void AffineForOp::build( - OpBuilder &builder, OperationState &result, ValueRange lbOperands, - AffineMap lbMap, ValueRange ubOperands, AffineMap ubMap, int64_t step, - function_ref bodyBuilder) { +void AffineForOp::build(OpBuilder &builder, OperationState &result, + ValueRange lbOperands,
AffineMap lbMap, + ValueRange ubOperands, AffineMap ubMap, int64_t step, + ValueRange iterArgs, BodyBuilderFn bodyBuilder) { assert(((!lbMap && lbOperands.empty()) || lbOperands.size() == lbMap.getNumInputs()) && "lower bound operand count does not match the affine map"); @@ -1185,6 +1185,10 @@ "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); + result.addOperands(iterArgs); + for (Value v : iterArgs) + result.addTypes(v.getType()); + // Add an attribute for the step. result.addAttribute(getStepAttrName(), builder.getIntegerAttr(builder.getIndexType(), step)); @@ -1200,53 +1204,83 @@ // Create a region and a block for the body. The argument of the region is // the loop induction variable. Region *bodyRegion = result.addRegion(); - Block *body = new Block; - Value inductionVar = body->addArgument(IndexType::get(builder.getContext())); - bodyRegion->push_back(body); - if (bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(body); - bodyBuilder(builder, result.location, inductionVar); - } else { + bodyRegion->push_back(new Block); + Block &bodyBlock = bodyRegion->front(); + Value inductionVar = bodyBlock.addArgument(builder.getIndexType()); + for (Value v : iterArgs) + bodyBlock.addArgument(v.getType()); + + // Create the default terminator if the builder is not provided and if the + // iteration arguments are not provided. Otherwise, leave this to the caller + // because we don't know which values to return from the loop.
+ if (iterArgs.empty() && !bodyBuilder) { ensureTerminator(*bodyRegion, builder, result.location); + } else if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + bodyBuilder(builder, result.location, inductionVar, + bodyBlock.getArguments().drop_front()); } } -void AffineForOp::build( - OpBuilder &builder, OperationState &result, int64_t lb, int64_t ub, - int64_t step, - function_ref bodyBuilder) { +void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, + int64_t ub, int64_t step, ValueRange iterArgs, + BodyBuilderFn bodyBuilder) { auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); - return build(builder, result, {}, lbMap, {}, ubMap, step, bodyBuilder); + return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, + bodyBuilder); } static LogicalResult verify(AffineForOp op) { // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); - if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex()) + if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body to have a single index argument for the " "induction variable"); - // Verify that there are enough operands for the bounds. - AffineMap lowerBoundMap = op.getLowerBoundMap(), - upperBoundMap = op.getUpperBoundMap(); - if (op.getNumOperands() != - (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) - return op.emitOpError( - "operand count must match with affine map dimension and symbol count"); - + // Verify that there are enough operands for the bounds and that the + // remaining (loop-carried) operands, the body arguments following the + // induction variable and the results all agree in number. + if (op.getNumIterOperands() < 0) + return op.emitOpError( + "operand count must match with affine map dimension and symbol count"); + unsigned opNumResults = op.getNumResults(); + if (op.getNumIterOperands() != opNumResults) + return op.emitOpError( + "mismatch in number of loop-carried values and defined values"); + if (op.getNumRegionIterArgs() != opNumResults) + return op.emitOpError( + "mismatch in number of basic block args and defined values"); + // Verify that the bound operands are valid dimension/symbols. /// Lower bound.
- if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), - op.getLowerBoundMap().getNumDims()))) - return failure(); + if (op.getLowerBoundMap().getNumInputs() != 0) + if (failed( + verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), + op.getLowerBoundMap().getNumDims()))) + return failure(); /// Upper bound. - if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), - op.getUpperBoundMap().getNumDims()))) - return failure(); + if (op.getUpperBoundMap().getNumInputs() != 0) + if (failed( + verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), + op.getUpperBoundMap().getNumDims()))) + return failure(); + + // Check that each loop-carried value has the same type as its initial + // operand, its body argument and the corresponding result. + auto iterOperands = op.getIterOperands(); + auto regionIterArgs = op.getRegionIterArgs(); + for (unsigned i = 0; i < opNumResults; ++i) { + Type resultType = op.getOperation()->getResult(i).getType(); + if (iterOperands[i].getType() != resultType || + regionIterArgs[i].getType() != resultType) + return op.emitOpError("type mismatch in loop-carried value #") + << i << " between initial operand, body argument and result"; + } + return success(); } @@ -1375,9 +1409,34 @@ "expected step to be representable as a positive signed integer"); } + // Parse the optional initial iteration arguments. + SmallVector regionArgs, operands; + SmallVector argTypes; + regionArgs.push_back(inductionVariable); + + if (succeeded(parser.parseOptionalKeyword("iter_args"))) { + // Parse assignment list and results type list. + if (parser.parseAssignmentList(regionArgs, operands) || + parser.parseArrowTypeList(result.types)) + return failure(); + // Resolve input operands. + for (auto operandType : llvm::zip(operands, result.types)) + if (parser.resolveOperand(std::get<0>(operandType), + std::get<1>(operandType), result.operands)) + return failure(); + } + // Induction variable.
+ Type indexType = builder.getIndexType(); + argTypes.push_back(indexType); + // Loop-carried variables. + argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. Region *body = result.addRegion(); - if (parser.parseRegion(*body, inductionVariable, builder.getIndexType())) + if (regionArgs.size() != argTypes.size()) + return parser.emitError( + parser.getNameLoc(), + "mismatch in number of loop-carried values and defined values"); + if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); AffineForOp::ensureTerminator(*body, builder, result.location); @@ -1427,6 +1486,16 @@ map.getNumDims(), p); } +int64_t AffineForOp::getNumIterOperands() { + AffineMap lbMap = getLowerBoundMapAttr().getValue(); + AffineMap ubMap = getUpperBoundMapAttr().getValue(); + + int64_t lbOperandNum = lbMap.getNumInputs(); + int64_t ubOperandNum = ubMap.getNumInputs(); + + return getOperation()->getNumOperands() - (lbOperandNum + ubOperandNum); +} + static void print(OpAsmPrinter &p, AffineForOp op) { p << op.getOperationName() << ' '; p.printOperand(op.getBody()->getArgument(0)); @@ -1437,9 +1506,22 @@ if (op.getStep() != 1) p << " step " << op.getStep(); + + bool printBlockTerminators = false; + if (op.getNumIterOperands() > 0) { + p << " iter_args("; + auto regionArgs = op.getRegionIterArgs(); + auto operands = op.getIterOperands(); + + llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { + p << std::get<0>(it) << " = " << std::get<1>(it); + }); + p << ") -> (" << op.getResultTypes() << ")"; + printBlockTerminators = true; + } + p.printRegion(op.region(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printEntryBlockArgs=*/false, printBlockTerminators); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getLowerBoundAttrName(), op.getUpperBoundAttrName(), @@ -1710,8 +1792,8 @@ ivs.reserve(lbs.size()); for (unsigned i = 0, e = lbs.size(); i < e; ++i) { // Callback for creating the loop body,
always creates the terminator. - auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, - Value iv) { + auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange iterArgs) { ivs.push_back(iv); // In the innermost loop, call the body builder. if (i == e - 1 && bodyBuilderFn) { @@ -1729,16 +1811,19 @@ } /// Creates an affine loop from the bounds known to be constants. -static AffineForOp buildAffineLoopFromConstants( - OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, - function_ref bodyBuilderFn) { - return builder.create(loc, lb, ub, step, bodyBuilderFn); +static AffineForOp +buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, + int64_t ub, int64_t step, + AffineForOp::BodyBuilderFn bodyBuilderFn) { + return builder.create(loc, lb, ub, step, /*iterArgs=*/llvm::None, + bodyBuilderFn); } /// Creates an affine loop from the bounds that may or may not be constants. -static AffineForOp buildAffineLoopFromValues( - OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, - function_ref bodyBuilderFn) { +static AffineForOp +buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, + int64_t step, + AffineForOp::BodyBuilderFn bodyBuilderFn) { auto lbConst = lb.getDefiningOp(); auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) @@ -1747,7 +1832,7 @@ bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, - bodyBuilderFn); + /*iterArgs=*/llvm::None, bodyBuilderFn); } void mlir::buildAffineLoopNest( diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -379,3 +379,14 @@ return } +// ----- + +func @affine_for(%buffer: memref<1024xf32>) -> f32 { + %sum_0 = constant 0.0 : f32 + // expected-error@+1 {{mismatch in number of loop-carried values and defined values}} + %res
= affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32, f32) { + %t = affine.load %buffer[%i] : memref<1024xf32> + affine.yield %t : f32 + } + return %res : f32 +} diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -184,3 +184,53 @@ // CHECK: return %[[OUT]] : f32 return %0 : f32 } + +// ----- + +// Test affine.for with yield values. + +#set = affine_set<(d0): (d0 - 10 >= 0)> + +// CHECK-LABEL: func @yield_loop +func @yield_loop(%buffer: memref<1024xf32>) -> f32 { + %sum_init_0 = constant 0.0 : f32 + %res = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_init_0) -> f32 { + %t = affine.load %buffer[%i] : memref<1024xf32> + %sum_next = affine.if #set(%i) -> (f32) { + %new_sum = addf %sum_iter, %t : f32 + affine.yield %new_sum : f32 + } else { + affine.yield %sum_iter : f32 + } + affine.yield %sum_next : f32 + } + return %res : f32 +} +// CHECK: %[[const_0:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[output:.*]] = affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %[[const_0]]) -> (f32) { +// CHECK: affine.if #set0(%{{.*}}) -> f32 { +// CHECK: affine.yield %{{.*}} : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: affine.yield %{{.*}} : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %{{.*}} : f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[output]] : f32 + +// CHECK-LABEL: func @multiple_yield +func @multiple_yield(%buffer: memref<1024xf32>) -> (f32, f32) { + %init_0 = constant 0.0 : f32 + %res1, %res2 = affine.for %i = 0 to 10 step 2 iter_args(%iter_arg1 = %init_0, %iter_arg2 = %init_0) -> (f32, f32) { + %t = affine.load %buffer[%i] : memref<1024xf32> + %ret1 = addf %t, %iter_arg1 : f32 + %ret2 = addf %t, %iter_arg2 : f32 + affine.yield %ret1, %ret2 : f32, f32 + } + return %res1, %res2 : f32, f32 +} +// CHECK: %[[const_0:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[output:[0-9]+]]:2 = affine.for %{{.*}} =
0 to 10 step 2 iter_args(%[[iter_arg1:.*]] = %[[const_0]], %[[iter_arg2:.*]] = %[[const_0]]) -> (f32, f32) { +// CHECK: %[[res1:.*]] = addf %{{.*}}, %[[iter_arg1]] : f32 +// CHECK-NEXT: %[[res2:.*]] = addf %{{.*}}, %[[iter_arg2]] : f32 +// CHECK-NEXT: affine.yield %[[res1]], %[[res2]] : f32, f32 +// CHECK-NEXT: }