diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -190,20 +190,67 @@ int rank = arryTmp.rank(); assert(rank >= 1); + bool absentMask = isStaticallyAbsent(args[2]); // Handle optional mask argument - auto mask = isStaticallyAbsent(args[2]) - ? builder.create( - loc, fir::BoxType::get(builder.getI1Type())) - : builder.createBox(loc, args[2]); + auto mask = absentMask ? builder.create( + loc, fir::BoxType::get(builder.getI1Type())) + : builder.createBox(loc, args[2]); bool absentDim = isStaticallyAbsent(args[1]); + mlir::Type ty = array.getType(); + mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty); + auto eleTy = arrTy.cast().getEleTy(); + + // Inline 1D array sum - when they are the basic SUM(arr) form. + // For now, don't support complex types. + if (rank == 1 && absentDim && absentMask && !fir::isa_complex(eleTy)) { + mlir::IndexType idxTy = builder.getIndexType(); + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); + mlir::Value lo = + fir::factory::readLowerBound(builder, loc, arryTmp, 0, one); + mlir::Value up = fir::factory::readExtent(builder, loc, arryTmp, 0); + mlir::Value step = one; + + mlir::Value init; + if (eleTy.isa()) + init = builder.createIntegerConstant(loc, eleTy, 0); + else + init = builder.createRealZeroConstant(loc, eleTy); + + auto savedPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(builder.getAllocaBlock()); + auto sum = builder.create(loc, eleTy); + builder.restoreInsertionPoint(savedPt); + builder.create(loc, init, sum); + + auto loop = builder.create(loc, lo, up, step); + + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(loop.getBody()); + + mlir::Type eleRefTy = builder.getRefType(eleTy); + // Get zero-based index by subtracting the lower bound. + mlir::Value index = + builder.create(loc, loop.getInductionVar(), lo); + auto addr = builder.create(loc, eleRefTy, array, index); + mlir::Value elem = builder.create(loc, addr); + mlir::Value sumVal = builder.create(loc, sum); + + mlir::Value res; + if (eleTy.isa()) + res = builder.create(loc, elem, sumVal); + else + res = builder.create(loc, elem, sumVal); + builder.create(loc, res, sum); + + builder.restoreInsertionPoint(insPt); + return builder.create(loc, sum); + } + // We call the type specific versions because the result is scalar // in the case below. if (absentDim || rank == 1) { - mlir::Type ty = array.getType(); - mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty); - auto eleTy = arrTy.cast().getEleTy(); if (fir::isa_complex(eleTy)) { mlir::Value result = builder.createTemporary(loc, eleTy); func(builder, loc, array, mask, result);