// Review summary of the patch below. These `//` lines sit before the first
// `diff --git` header, outside every hunk; the patch text itself is unchanged.
//
// mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
//   * AffineYieldOp lowering: if the parent op passes the `isa` check, the
//     yield is replaced by a new op built without operands and the pattern
//     returns early; otherwise the replacement op is built with
//     `op.operands()`. The in-patch comment ties the early-return case to
//     scf.parallel, which models reductions with extra ops in its region.
//   * affine.for lowering: the new loop op is created with
//     `op.getIterOperands()` in addition to the bounds and step, the original
//     region is inlined into it, and the original op is replaced with the new
//     op's results (previously it was erased).
//   * New file-local helpers `getIdentityValue` and `getReductionOp` switch
//     over AtomicRMWKind and handle addf, addi, mulf and muli. Any other kind
//     calls `emitOptionalError(loc, ...)` and the function returns nullptr.
//   * AffineParallelLowering: the lower-bound map is now expanded before the
//     upper-bound map. With no results, the parallel op is created, the body
//     is inlined and the op is replaced. With results, one identity value per
//     entry of `op.reductions()` is collected into `identityVals` and passed
//     when creating the parallel op; after inlining the body, one reduce op is
//     inserted per terminator operand before the last op of the body block,
//     and its region returns `getReductionOp` applied to the region's two
//     block arguments.
//
// mlir/test/Conversion/AffineToStandard/lower-affine.mlir
//   * Adds FileCheck cases `for_with_yield`, `affine_parallel_simple`,
//     `affine_parallel_simple_dynamic_bounds` and
//     `affine_parallel_with_reductions`.
//
// NOTE(review): `getIdentityValue` builds its constants from
//   `getF32FloatAttr` / `getI32IntegerAttr` only, independent of the type of
//   the value being reduced -- looks like reductions over other float/integer
//   widths would get a mismatched init value; confirm against the op verifier.
// NOTE(review): the nullptr returned for unsupported reduction kinds is pushed
//   into `identityVals` / passed to the reduce-return op without a check --
//   presumably this should make the pattern `return failure()`; confirm.
// NOTE(review): the first reductions loop wraps `getInt()` in
//   `static_cast(...)` before `symbolizeAtomicRMWKind`, the second loop does
//   not; the two should agree.
// NOTE(review): the reduction kind is validated with `assert` only, so the
//   check disappears in builds that compile assertions out.
// NOTE(review): in this copy the line breaks are collapsed and angle-bracket
//   argument lists (after `create`, `isa`, `cast`, `SmallVector`, `Optional`,
//   `OpRewritePattern`, `memref`, ...) appear to be missing, so the text will
//   not apply or compile as-is -- re-export the patch before acting on it.
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -334,7 +334,13 @@ LogicalResult matchAndRewrite(AffineYieldOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op); + if (isa(op.getParentOp())) { + // scf.parallel does not yield any values via its terminator scf.yield but + // models reductions differently using additional ops in its region. + rewriter.replaceOpWithNewOp(op); + return success(); + } + rewriter.replaceOpWithNewOp(op, op.operands()); return success(); } }; @@ -349,14 +355,55 @@ Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); Value step = rewriter.create(loc, op.getStep()); - auto f = rewriter.create(loc, lowerBound, upperBound, step); - rewriter.eraseBlock(f.getBody()); - rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); - rewriter.eraseOp(op); + auto scfForOp = rewriter.create(loc, lowerBound, upperBound, + step, op.getIterOperands()); + rewriter.eraseBlock(scfForOp.getBody()); + rewriter.inlineRegionBefore(op.region(), scfForOp.region(), + scfForOp.region().end()); + rewriter.replaceOp(op, scfForOp.results()); return success(); } }; +/// Returns the identity value associated with an AtomicRMWKind op. +static Value getIdentityValue(AtomicRMWKind op, OpBuilder &builder, + Location loc) { + switch (op) { + case AtomicRMWKind::addf: + return builder.create(loc, builder.getF32FloatAttr(0)); + case AtomicRMWKind::addi: + return builder.create(loc, builder.getI32IntegerAttr(0)); + case AtomicRMWKind::mulf: + return builder.create(loc, builder.getF32FloatAttr(1)); + case AtomicRMWKind::muli: + return builder.create(loc, builder.getI32IntegerAttr(1)); + // TODO: Add remaining reduction operations. 
+ default: + emitOptionalError(loc, "Reduction operation type not supported"); + } + return nullptr; +} + +/// Returns the value of reduction operation associated with an AtomicRMWKind +/// op. +static Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, + Value lhs, Value rhs) { + switch (op) { + case AtomicRMWKind::addf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::addi: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::mulf: + return builder.create(loc, lhs, rhs); + case AtomicRMWKind::muli: + return builder.create(loc, lhs, rhs); + // TODO: Add remaining reduction operations. + default: + emitOptionalError(loc, "Reduction operation type not supported"); + } + return nullptr; +} + /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel` /// operation. class AffineParallelLowering : public OpRewritePattern { @@ -369,12 +416,13 @@ SmallVector steps; SmallVector upperBoundTuple; SmallVector lowerBoundTuple; + SmallVector identityVals; // Finding lower and upper bound by expanding the map expression. // Checking if expandAffineMap is not giving NULL. - Optional> upperBound = expandAffineMap( - rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands()); Optional> lowerBound = expandAffineMap( rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands()); + Optional> upperBound = expandAffineMap( + rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands()); if (!lowerBound || !upperBound) return failure(); upperBoundTuple = *upperBound; @@ -383,13 +431,60 @@ for (Attribute step : op.steps()) steps.push_back(rewriter.create( loc, step.cast().getInt())); - // Creating empty scf.parallel op body with appropriate bounds. 
- auto parallelOp = rewriter.create(loc, lowerBoundTuple, - upperBoundTuple, steps); - rewriter.eraseBlock(parallelOp.getBody()); - rewriter.inlineRegionBefore(op.region(), parallelOp.region(), - parallelOp.region().end()); - rewriter.eraseOp(op); + // Get the terminator op. + Operation *affineParOpTerminator = op.getBody()->getTerminator(); + scf::ParallelOp parOp; + if (op.results().empty()) { + // Case with no reduction operations/return values. + parOp = rewriter.create(loc, lowerBoundTuple, + upperBoundTuple, steps, nullptr); + rewriter.eraseBlock(parOp.getBody()); + rewriter.inlineRegionBefore(op.region(), parOp.region(), + parOp.region().end()); + rewriter.replaceOp(op, parOp.results()); + return success(); + } + // Case with affine.parallel having reduction operations/return values. + // scf.parallel handles the reduction operation differently unlike + // affine.parallel. + ArrayRef reductions = op.reductions().getValue(); + for (Attribute reduction : reductions) { + // For each of the reduction operations get the identity values for + // initialization of the result values. + Optional reductionOp = symbolizeAtomicRMWKind( + static_cast(reduction.cast().getInt())); + assert(reductionOp.hasValue() && + "Reduction Operation cannot be of None Type"); + AtomicRMWKind reductionOpValue = reductionOp.getValue(); + identityVals.push_back(getIdentityValue(reductionOpValue, rewriter, loc)); + } + parOp = rewriter.create( + loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, nullptr); + + // Copy the same body as of the AffineParallel operation. + rewriter.eraseBlock(parOp.getBody()); + rewriter.inlineRegionBefore(op.region(), parOp.region(), + parOp.region().end()); + assert(reductions.size() == affineParOpTerminator->getNumOperands() && + "Unequal number of reductions and operands."); + for (unsigned i = 0, end = reductions.size(); i < end; i++) { + // For each of the reduction operations get the respective mlir::Value. 
+ Optional reductionOp = + symbolizeAtomicRMWKind(reductions[i].cast().getInt()); + assert(reductionOp.hasValue() && + "Reduction Operation cannot be of None Type"); + AtomicRMWKind reductionOpValue = reductionOp.getValue(); + rewriter.setInsertionPoint(&parOp.getBody()->back()); + auto reduceOp = rewriter.create( + loc, affineParOpTerminator->getOperand(i)); + rewriter.setInsertionPointToEnd(&reduceOp.reductionOperator().front()); + Value resultReduction = + getReductionOp(reductionOpValue, rewriter, loc, + reduceOp.reductionOperator().front().getArgument(0), + reduceOp.reductionOperator().front().getArgument(1)); + rewriter.create(loc, resultReduction); + } + rewriter.replaceOp(op, parOp.results()); return success(); } }; diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -26,6 +26,30 @@ ///////////////////////////////////////////////////////////////////// +func @for_with_yield(%buffer: memref<1024xf32>) -> (f32) { + %sum_0 = constant 0.0 : f32 + %sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) { + %t = affine.load %buffer[%i] : memref<1024xf32> + %sum_next = addf %sum_iter, %t : f32 + affine.yield %sum_next : f32 + } + return %sum : f32 +} + +// CHECK-LABEL: func @for_with_yield +// CHECK: %[[INIT_SUM:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[LOWER:.*]] = constant 0 : index +// CHECK-NEXT: %[[UPPER:.*]] = constant 10 : index +// CHECK-NEXT: %[[STEP:.*]] = constant 2 : index +// CHECK-NEXT: %[[SUM:.*]] = scf.for %[[IV:.*]] = %[[LOWER]] to %[[UPPER]] step %[[STEP]] iter_args(%[[SUM_ITER:.*]] = %[[INIT_SUM]]) -> (f32) { +// CHECK-NEXT: load +// CHECK-NEXT: %[[SUM_NEXT:.*]] = addf +// CHECK-NEXT: scf.yield %[[SUM_NEXT]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[SUM]] : f32 + 
+///////////////////////////////////////////////////////////////////// + func private @pre(index) -> () func private @body2(index, index) -> () func private @post(index) -> () @@ -674,3 +698,104 @@ // CHECK: %[[A4:.*]] = load %[[ARG2]][%[[arg8]], %[[arg7]]] : memref<100x100xf32> // CHECK: mulf %[[A3]], %[[A4]] : f32 // CHECK: scf.yield + +///////////////////////////////////////////////////////////////////// + +func @affine_parallel_simple(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) -> (memref<3x3xf32>) { + %O = alloc() : memref<3x3xf32> + affine.parallel (%kx, %ky) = (0, 0) to (2, 2) { + %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf32> + %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf32> + %3 = mulf %1, %2 : f32 + affine.store %3, %O[%kx, %ky] : memref<3x3xf32> + } + return %O : memref<3x3xf32> +} +// CHECK-LABEL: func @affine_parallel_simple +// CHECK: %[[LOWER_1:.*]] = constant 0 : index +// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index +// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index +// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index +// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index +// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index +// CHECK-NEXT: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) { +// CHECK-NEXT: %[[VAL_1:.*]] = load +// CHECK-NEXT: %[[VAL_2:.*]] = load +// CHECK-NEXT: %[[PRODUCT:.*]] = mulf +// CHECK-NEXT: store +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +///////////////////////////////////////////////////////////////////// + +func @affine_parallel_simple_dynamic_bounds(%arg0: memref, %arg1: memref, %arg2: memref) { + %c_0 = constant 0 : index + %output_dim = dim %arg0, %c_0 : memref + affine.parallel (%kx, %ky) = (%c_0, %c_0) to (%output_dim, %output_dim) { + %1 = affine.load %arg0[%kx, %ky] : memref + %2 = affine.load %arg1[%kx, %ky] : memref + %3 = mulf %1, %2 : f32 + affine.store %3, %arg2[%kx, 
%ky] : memref + } + return +} +// CHECK-LABEL: func @affine_parallel_simple_dynamic_bounds +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[ARG_2:.*]]: memref +// CHECK: %[[DIM_INDEX:.*]] = constant 0 : index +// CHECK-NEXT: %[[UPPER:.*]] = dim %[[ARG_0]], %[[DIM_INDEX]] : memref +// CHECK-NEXT: %[[LOWER_1:.*]] = constant 0 : index +// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index +// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index +// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index +// CHECK-NEXT: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER]], %[[UPPER]]) step (%[[STEP_1]], %[[STEP_2]]) { +// CHECK-NEXT: %[[VAL_1:.*]] = load +// CHECK-NEXT: %[[VAL_2:.*]] = load +// CHECK-NEXT: %[[PRODUCT:.*]] = mulf +// CHECK-NEXT: store +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +///////////////////////////////////////////////////////////////////// + +func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) -> (f32, f32) { + %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf", "mulf") -> (f32, f32) { + %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf32> + %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf32> + %3 = mulf %1, %2 : f32 + %4 = addf %1, %2 : f32 + affine.yield %3, %4 : f32, f32 + } + return %0#0, %0#1 : f32, f32 +} +// CHECK-LABEL: func @affine_parallel_with_reductions +// CHECK: %[[LOWER_1:.*]] = constant 0 : index +// CHECK-NEXT: %[[LOWER_2:.*]] = constant 0 : index +// CHECK-NEXT: %[[UPPER_1:.*]] = constant 2 : index +// CHECK-NEXT: %[[UPPER_2:.*]] = constant 2 : index +// CHECK-NEXT: %[[STEP_1:.*]] = constant 1 : index +// CHECK-NEXT: %[[STEP_2:.*]] = constant 1 : index +// CHECK-NEXT: %[[INIT_1:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[INIT_2:.*]] = constant 1.000000e+00 : f32 +// CHECK-NEXT: %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step 
(%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (f32, f32) { +// CHECK-NEXT: %[[VAL_1:.*]] = load +// CHECK-NEXT: %[[VAL_2:.*]] = load +// CHECK-NEXT: %[[PRODUCT:.*]] = mulf +// CHECK-NEXT: %[[SUM:.*]] = addf +// CHECK-NEXT: scf.reduce(%[[PRODUCT]]) : f32 { +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK-NEXT: %[[RES:.*]] = addf +// CHECK-NEXT: scf.reduce.return %[[RES]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.reduce(%[[SUM]]) : f32 { +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK-NEXT: %[[RES:.*]] = mulf +// CHECK-NEXT: scf.reduce.return %[[RES]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: }