diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -368,17 +368,19 @@
 };
 
 /// Returns the identity value associated with an AtomicRMWKind op.
-static Value getIdentityValue(AtomicRMWKind op, OpBuilder &builder,
-                              Location loc) {
+static Value getIdentityValue(AtomicRMWKind op, Type resultType,
+                              OpBuilder &builder, Location loc) {
   switch (op) {
   case AtomicRMWKind::addf:
-    return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(0));
+    return builder.create<ConstantOp>(loc, builder.getFloatAttr(resultType, 0));
   case AtomicRMWKind::addi:
-    return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(0));
+    return builder.create<ConstantOp>(loc,
+                                      builder.getIntegerAttr(resultType, 0));
   case AtomicRMWKind::mulf:
-    return builder.create<ConstantOp>(loc, builder.getF32FloatAttr(1));
+    return builder.create<ConstantOp>(loc, builder.getFloatAttr(resultType, 1));
   case AtomicRMWKind::muli:
-    return builder.create<ConstantOp>(loc, builder.getI32IntegerAttr(1));
+    return builder.create<ConstantOp>(loc,
+                                      builder.getIntegerAttr(resultType, 1));
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -453,15 +455,18 @@
     // scf.parallel handles the reduction operation differently unlike
     // affine.parallel.
     ArrayRef<Attribute> reductions = op.reductions().getValue();
-    for (Attribute reduction : reductions) {
+    for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
       // For each of the reduction operations get the identity values for
      // initialization of the result values.
+      Attribute reduction = std::get<0>(pair);
+      Type resultType = std::get<1>(pair);
       Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
           static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
       assert(reductionOp.hasValue() &&
              "Reduction operation cannot be of None Type");
       AtomicRMWKind reductionOpValue = reductionOp.getValue();
-      identityVals.push_back(getIdentityValue(reductionOpValue, rewriter, loc));
+      identityVals.push_back(
+          getIdentityValue(reductionOpValue, resultType, rewriter, loc));
     }
     parOp = rewriter.create<scf::ParallelOp>(
         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -826,3 +826,81 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return
 // CHECK-NEXT: }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_with_reductions_f64(%arg0: memref<3x3xf64>, %arg1: memref<3x3xf64>) -> (f64, f64) {
+  %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf", "mulf") -> (f64, f64) {
+    %1 = affine.load %arg0[%kx, %ky] : memref<3x3xf64>
+    %2 = affine.load %arg1[%kx, %ky] : memref<3x3xf64>
+    %3 = mulf %1, %2 : f64
+    %4 = addf %1, %2 : f64
+    affine.yield %3, %4 : f64, f64
+  }
+  return %0#0, %0#1 : f64, f64
+}
+// CHECK-LABEL: @affine_parallel_with_reductions_f64
+// CHECK: %[[LOWER_1:.*]] = constant 0 : index
+// CHECK: %[[LOWER_2:.*]] = constant 0 : index
+// CHECK: %[[UPPER_1:.*]] = constant 2 : index
+// CHECK: %[[UPPER_2:.*]] = constant 2 : index
+// CHECK: %[[STEP_1:.*]] = constant 1 : index
+// CHECK: %[[STEP_2:.*]] = constant 1 : index
+// CHECK: %[[INIT_1:.*]] = constant 0.000000e+00 : f64
+// CHECK: %[[INIT_2:.*]] = constant 1.000000e+00 : f64
+// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (f64, f64) {
+// CHECK: %[[VAL_1:.*]] = memref.load
+// CHECK: %[[VAL_2:.*]] = memref.load
+// CHECK: %[[PRODUCT:.*]] = mulf
+// CHECK: %[[SUM:.*]] = addf
+// CHECK: scf.reduce(%[[PRODUCT]]) : f64 {
+// CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
+// CHECK: %[[RES:.*]] = addf
+// CHECK: scf.reduce.return %[[RES]] : f64
+// CHECK: }
+// CHECK: scf.reduce(%[[SUM]]) : f64 {
+// CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
+// CHECK: %[[RES:.*]] = mulf
+// CHECK: scf.reduce.return %[[RES]] : f64
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: }
+
+/////////////////////////////////////////////////////////////////////
+
+func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>) -> (i64, i64) {
+  %0:2 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addi", "muli") -> (i64, i64) {
+    %1 = affine.load %arg0[%kx, %ky] : memref<3x3xi64>
+    %2 = affine.load %arg1[%kx, %ky] : memref<3x3xi64>
+    %3 = muli %1, %2 : i64
+    %4 = addi %1, %2 : i64
+    affine.yield %3, %4 : i64, i64
+  }
+  return %0#0, %0#1 : i64, i64
+}
+// CHECK-LABEL: @affine_parallel_with_reductions_i64
+// CHECK: %[[LOWER_1:.*]] = constant 0 : index
+// CHECK: %[[LOWER_2:.*]] = constant 0 : index
+// CHECK: %[[UPPER_1:.*]] = constant 2 : index
+// CHECK: %[[UPPER_2:.*]] = constant 2 : index
+// CHECK: %[[STEP_1:.*]] = constant 1 : index
+// CHECK: %[[STEP_2:.*]] = constant 1 : index
+// CHECK: %[[INIT_1:.*]] = constant 0 : i64
+// CHECK: %[[INIT_2:.*]] = constant 1 : i64
+// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[LOWER_1]], %[[LOWER_2]]) to (%[[UPPER_1]], %[[UPPER_2]]) step (%[[STEP_1]], %[[STEP_2]]) init (%[[INIT_1]], %[[INIT_2]]) -> (i64, i64) {
+// CHECK: %[[VAL_1:.*]] = memref.load
+// CHECK: %[[VAL_2:.*]] = memref.load
+// CHECK: %[[PRODUCT:.*]] = muli
+// CHECK: %[[SUM:.*]] = addi
+// CHECK: scf.reduce(%[[PRODUCT]]) : i64 {
+// CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
+// CHECK: %[[RES:.*]] = addi
+// CHECK: scf.reduce.return %[[RES]] : i64
+// CHECK: }
+// CHECK: scf.reduce(%[[SUM]]) : i64 {
+// CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
+// CHECK: %[[RES:.*]] = muli
+// CHECK: scf.reduce.return %[[RES]] : i64
+// CHECK: }
+// CHECK: scf.yield
+// CHECK: }
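
Note (reviewer context, not part of the patch): getIdentityValue previously hardcoded f32/i32 identity attributes, so lowering an affine.parallel reduction over f64 or i64 produced scf.parallel init operands of the wrong element type. Threading the reduction's result type through fixes that. A minimal caller sketch of the new signature, assuming an OpBuilder named builder and a Location named loc are in scope:

  // Identity for an f64 "addf" reduction: materializes
  //   constant 0.000000e+00 : f64
  // rather than the previously hardcoded f32 constant.
  Value init =
      getIdentityValue(AtomicRMWKind::addf, builder.getF64Type(), builder, loc);

The new test cases are exercised by the file's existing RUN line (mlir-opt -lower-affine piped into FileCheck).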