Index: mlir/include/mlir/Dialect/Arith/IR/Arith.h =================================================================== --- mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -128,6 +128,9 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc); +/// Checks if `value` is an identity value associated with an AtomicRMWKind op. +bool isIdentityValue(AtomicRMWKind op, Value value); + /// Returns the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Index: mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp =================================================================== --- mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1205,6 +1205,38 @@ } }; +/// Checks every iterOperand of `forOp` from `reductions`. If an iterOperand is +/// not a neutral element for the corresponding reduction kind, replaces the +/// iterOperand with a neutral element and adds an op after `forOp` to combine +/// the original iterOperand and the related result of `forOp`. 
+static void +replaceIterOperandWithNeutralElement(AffineForOp forOp, + ArrayRef reductions) { + unsigned numIterArgs = forOp.getNumIterOperands(); + if (numIterArgs == 0) + return; + assert(reductions.size() == numIterArgs); + OpBuilder builder(forOp); + auto loc = forOp.getLoc(); + auto iterOperands = forOp.getIterOperands(); + for (const LoopReduction &reduction : reductions) { + unsigned pos = reduction.iterArgPosition; + arith::AtomicRMWKind kind = reduction.kind; + Value iterOperand = iterOperands[pos]; + if (isIdentityValue(kind, iterOperand)) + continue; + builder.setInsertionPoint(forOp); + Value neutralElem = + getIdentityValue(kind, iterOperand.getType(), builder, loc); + forOp.setOperand(pos + forOp.getNumControlOperands(), neutralElem); + builder.setInsertionPointAfter(forOp); + auto res = forOp.getResult(pos); + // Re-apply the original initial value to the loop result after the loop. + Value newRes = getReductionOp(kind, builder, loc, res, iterOperand); + res.replaceAllUsesExcept(newRes, newRes.getDefiningOp()); + } +} + /// Unrolls and jams this loop by the specified factor. LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor) { @@ -1247,8 +1279,12 @@ // Get supported reductions to be used for creating reduction ops at the end. SmallVector reductions; - if (forOp.getNumIterOperands() > 0) + if (forOp.getNumIterOperands() > 0) { getSupportedReductions(forOp, reductions); + // Make each iterOperand of `forOp` listed in `reductions` a neutral element + // of its reduction kind; the original value is combined after the loop. + replaceIterOperandWithNeutralElement(forOp, reductions); + } // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp =================================================================== --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2338,10 +2338,29 @@ /// Returns the identity value associated with an AtomicRMWKind op. 
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc) { - Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); + Type scalarTy = getElementTypeOrSelf(resultType); + Attribute attr = getIdentityValueAttr(op, scalarTy, builder, loc); + if (scalarTy != resultType) { + attr = DenseElementsAttr::get(resultType, attr); + } return builder.create(loc, attr); } +/// Checks if `value` is an identity value associated with an AtomicRMWKind op. +bool mlir::arith::isIdentityValue(AtomicRMWKind op, Value value) { + auto constOp = dyn_cast_or_null(value.getDefiningOp()); + if (!constOp) + return false; + OpBuilder builder(value.getContext()); + Type scalarTy = getElementTypeOrSelf(value); + Attribute valueAttr = + getIdentityValueAttr(op, scalarTy, builder, value.getLoc()); + if (constOp.getValue().isa()) + return constOp.getValue() == + DenseElementsAttr::get(value.getType(), valueAttr); + return constOp.getValue() == valueAttr; +} + /// Return the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Index: mlir/test/Dialect/Affine/unroll-jam.mlir =================================================================== --- mlir/test/Dialect/Affine/unroll-jam.mlir +++ mlir/test/Dialect/Affine/unroll-jam.mlir @@ -458,7 +458,8 @@ } // CHECK: %[[CONST0:[a-zA-Z0-9_]*]] = arith.constant 20 : index -// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (f32, f32) { +// CHECK-NEXT: [[CST1:%[a-zA-Z0-9_]*]] = arith.constant 1.000000e+00 : f32 +// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[CST1]], [[ACC1:%arg[0-9]+]] = [[CST1]]) -> (f32, f32) { // CHECK-NEXT: [[RES1:%[0-9]+]]:2 = affine.for %[[IV1:arg[0-9]+]] = 0 to 30 iter_args([[ACC2:%arg[0-9]+]] = [[INIT1]], [[ACC3:%arg[0-9]+]] = [[INIT1]]) -> (f32, f32) { // CHECK-NEXT: [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]], %[[IV1]]] // CHECK-NEXT: [[ADD1:%[0-9]+]] = arith.addf [[ACC2]], [[LOAD1]] : f32 @@ -481,6 +482,7 @@ // CHECK-NEXT: affine.yield [[ADD3]] : f32 // CHECK-NEXT: } // CHECK-NEXT: [[MUL4:%[0-9]+]] = arith.mulf [[MUL3]], [[RES2]] : f32 +// CHECK-NEXT: [[MUL5:%[0-9]+]] = arith.mulf [[MUL4]], [[INIT0]] : f32 // CHECK-NEXT: return // CHECK-LABEL: func @unroll_jam_iter_args_addi @@ -495,7 +497,8 @@ } // CHECK: %[[CONST0:[a-zA-Z0-9_]*]] = arith.constant 20 : index -// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (i32, i32) { +// CHECK-NEXT: [[CST0:%[a-zA-Z0-9_]*]] = arith.constant 0 : i32 +// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[CST0]], [[ACC1:%arg[0-9]+]] = [[CST0]]) -> (i32, i32) { // CHECK-NEXT: [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]]] // CHECK-NEXT: [[ADD1:%[0-9]+]] = 
arith.addi [[ACC0]], [[LOAD1]] : i32 // CHECK-NEXT: %[[INC1:[0-9]+]] = affine.apply [[$MAP_PLUS_1]](%[[IV0]]) @@ -508,4 +511,5 @@ // Cleanup loop (single iteration). // CHECK-NEXT: [[LOAD3:%[0-9]+]] = affine.load {{.*}}[%[[CONST0]]] // CHECK-NEXT: [[ADD4:%[0-9]+]] = arith.addi [[ADD3]], [[LOAD3]] : i32 +// CHECK-NEXT: [[ADD5:%[0-9]+]] = arith.addi [[ADD4]], [[INIT0]] : i32 // CHECK-NEXT: return