[flang] Explicit TODOs for half-precision reductions; tidy integer type checks

In the MAXVAL, MINVAL, PRODUCT, DOT_PRODUCT and SUM runtime-call builders:
- Half-precision element types now stop at an explicit TODO(loc, ...) instead
  of reaching the generic "invalid type" fatal error:
  - real f16/bf16 (eleTy.isF16() || eleTy.isBF16()) for all five intrinsics;
  - complex kinds 2 and 3 (fir::ComplexType::get(ctx, 2|3)) for PRODUCT,
    DOT_PRODUCT and SUM.
- Integer-kind checks use eleTy.isInteger(<kind bitsize>) rather than
  comparing eleTy against builder.getIntegerType(<kind bitsize>).
- Fatal-error and TODO messages name the Fortran intrinsic in upper case,
  e.g. "invalid type in MAXVAL"; the dot-product messages use the standard
  intrinsic name DOT_PRODUCT.

diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp --- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp @@ -545,7 +545,9 @@ auto eleTy = arrTy.cast().getEleTy(); auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0); - if (eleTy.isF32()) + if (eleTy.isF16() || eleTy.isBF16()) + TODO(loc, "half-precision MAXVAL"); + else if (eleTy.isF32()) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF64()) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -553,23 +555,18 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF128()) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16))) func = fir::runtime::getRuntimeFunc(loc, builder); else - fir::emitFatalError(loc, "invalid type in Maxval lowering"); + fir::emitFatalError(loc, "invalid type in MAXVAL"); auto fTy = func.getFunctionType(); auto sourceFile = 
fir::factory::locationToFilename(builder, loc); @@ -664,7 +661,9 @@ auto eleTy = arrTy.cast().getEleTy(); auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0); - if (eleTy.isF32()) + if (eleTy.isF16() || eleTy.isBF16()) + TODO(loc, "half-precision MINVAL"); + else if (eleTy.isF32()) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF64()) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -672,23 +671,18 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF128()) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16))) func = fir::runtime::getRuntimeFunc(loc, builder); else - fir::emitFatalError(loc, "invalid type in Minval lowering"); + fir::emitFatalError(loc, "invalid type in MINVAL"); auto fTy = func.getFunctionType(); auto sourceFile = fir::factory::locationToFilename(builder, loc); @@ -721,7 +715,9 @@ auto eleTy = arrTy.cast().getEleTy(); auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0); - if 
(eleTy.isF32()) + if (eleTy.isF16() || eleTy.isBF16()) + TODO(loc, "half-precision PRODUCT"); + else if (eleTy.isF32()) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF64()) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -729,20 +725,15 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF128()) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16))) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy == fir::ComplexType::get(builder.getContext(), 4)) func = @@ -754,8 +745,11 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy == fir::ComplexType::get(builder.getContext(), 16)) func = fir::runtime::getRuntimeFunc(loc, builder); + else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) || + eleTy == fir::ComplexType::get(builder.getContext(), 3)) + TODO(loc, "half-precision PRODUCT"); else - fir::emitFatalError(loc, "invalid type in Product lowering"); + fir::emitFatalError(loc, "invalid type in 
PRODUCT"); auto fTy = func.getFunctionType(); auto sourceFile = fir::factory::locationToFilename(builder, loc); @@ -788,7 +782,9 @@ auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty); auto eleTy = arrTy.cast().getEleTy(); - if (eleTy.isF32()) + if (eleTy.isF16() || eleTy.isBF16()) + TODO(loc, "half-precision DOT_PRODUCT"); + else if (eleTy.isF32()) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF64()) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -808,31 +804,29 @@ else if (eleTy == fir::ComplexType::get(builder.getContext(), 16)) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1))) + else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) || + eleTy == fir::ComplexType::get(builder.getContext(), 3)) + TODO(loc, "half-precision DOT_PRODUCT"); + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16))) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isa()) func = fir::runtime::getRuntimeFunc(loc, builder); else - fir::emitFatalError(loc, "invalid type in DotProduct lowering"); + 
fir::emitFatalError(loc, "invalid type in DOT_PRODUCT"); auto fTy = func.getFunctionType(); auto sourceFile = fir::factory::locationToFilename(builder, loc); @@ -873,7 +867,9 @@ auto eleTy = arrTy.cast().getEleTy(); auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0); - if (eleTy.isF32()) + if (eleTy.isF16() || eleTy.isBF16()) + TODO(loc, "half-precision SUM"); + else if (eleTy.isF32()) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF64()) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -881,20 +877,15 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy.isF128()) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8))) func = fir::runtime::getRuntimeFunc(loc, builder); - else if (eleTy == - builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16))) + else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16))) func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy == fir::ComplexType::get(builder.getContext(), 4)) func = fir::runtime::getRuntimeFunc(loc, builder); @@ -904,8 +895,11 @@ func = fir::runtime::getRuntimeFunc(loc, builder); else if (eleTy == fir::ComplexType::get(builder.getContext(), 16)) 
func = fir::runtime::getRuntimeFunc(loc, builder); + else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) || + eleTy == fir::ComplexType::get(builder.getContext(), 3)) + TODO(loc, "half-precision SUM"); else - fir::emitFatalError(loc, "invalid type in Sum lowering"); + fir::emitFatalError(loc, "invalid type in SUM"); auto fTy = func.getFunctionType(); auto sourceFile = fir::factory::locationToFilename(builder, loc);