diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
new file mode 100644
--- /dev/null
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
@@ -0,0 +1,74 @@
+//===-- Transformational.h --------------------------------------*- C++ -*-===//
+// Generate transformational intrinsic runtime API calls.
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_TRANSFORMATIONAL_H
+#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_TRANSFORMATIONAL_H
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+namespace fir {
+class ExtendedValue;
+class FirOpBuilder;
+} // namespace fir
+
+namespace fir::runtime {
+
+/// Generate call to Cshift intrinsic runtime routine (takes a dim box).
+void genCshift(fir::FirOpBuilder &builder, mlir::Location loc,
+               mlir::Value resultBox, mlir::Value arrayBox,
+               mlir::Value shiftBox, mlir::Value dimBox);
+
+/// Generate call to the vector version of the Cshift intrinsic (no dim box).
+void genCshiftVector(fir::FirOpBuilder &builder, mlir::Location loc,
+                     mlir::Value resultBox, mlir::Value arrayBox,
+                     mlir::Value shiftBox);
+
+/// Generate call to Eoshift intrinsic runtime routine (takes a dim box).
+void genEoshift(fir::FirOpBuilder &builder, mlir::Location loc,
+                mlir::Value resultBox, mlir::Value arrayBox,
+                mlir::Value shiftBox, mlir::Value boundBox, mlir::Value dimBox);
+
+/// Generate call to the vector version of the Eoshift intrinsic (no dim box).
+void genEoshiftVector(fir::FirOpBuilder &builder, mlir::Location loc,
+                      mlir::Value resultBox, mlir::Value arrayBox,
+                      mlir::Value shiftBox, mlir::Value boundBox);
+
+/// Generate call to Matmul intrinsic runtime routine. \p resultBox comes
+/// first, as in the definition and in every other entry point declared here.
+void genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
+               mlir::Value resultBox, mlir::Value matrixABox,
+               mlir::Value matrixBBox);
+
+/// Generate call to Pack intrinsic runtime routine.
+void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
+             mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
+             mlir::Value vectorBox);
+
+/// Generate call to Reshape intrinsic runtime routine.
+void genReshape(fir::FirOpBuilder &builder, mlir::Location loc,
+                mlir::Value resultBox, mlir::Value sourceBox,
+                mlir::Value shapeBox, mlir::Value padBox, mlir::Value orderBox);
+
+/// Generate call to Spread intrinsic runtime routine.
+void genSpread(fir::FirOpBuilder &builder, mlir::Location loc,
+               mlir::Value resultBox, mlir::Value sourceBox, mlir::Value dim,
+               mlir::Value ncopies);
+
+/// Generate call to Transpose intrinsic runtime routine.
+void genTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                  mlir::Value resultBox, mlir::Value sourceBox);
+
+/// Generate call to Unpack intrinsic runtime routine.
+void genUnpack(fir::FirOpBuilder &builder, mlir::Location loc,
+               mlir::Value resultBox, mlir::Value vectorBox,
+               mlir::Value maskBox, mlir::Value fieldBox);
+
+} // namespace fir::runtime
+
+#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_TRANSFORMATIONAL_H
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -8,6 +8,7 @@
   FIRBuilder.cpp
   MutableBox.cpp
   Runtime/Reduction.cpp
+  Runtime/Transformational.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -0,0 +1,176 @@
+//===-- Transformational.cpp ------------------------------------*- C++ -*-===//
+// Generate transformational intrinsic runtime API calls.
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Runtime/Transformational.h"
+#include "flang/Lower/Todo.h"
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Character.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Runtime/matmul.h"
+#include "flang/Runtime/transformational.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+using namespace Fortran::runtime;
+
+/// Generate call to Cshift intrinsic runtime routine, forwarding dimBox.
+void fir::runtime::genCshift(fir::FirOpBuilder &builder, mlir::Location loc,
+                             mlir::Value resultBox, mlir::Value arrayBox,
+                             mlir::Value shiftBox, mlir::Value dimBox) {
+  auto cshiftFunc = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = cshiftFunc.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, arrayBox,
+                                    shiftBox, dimBox, sourceFile, sourceLine);
+  builder.create(loc, cshiftFunc, args);
+}
+
+/// Generate call to the vector version of the Cshift intrinsic (no dimBox).
+void fir::runtime::genCshiftVector(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value resultBox,
+                                   mlir::Value arrayBox, mlir::Value shiftBox) {
+  auto cshiftFunc =
+      fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = cshiftFunc.getType();
+
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto args = fir::runtime::createArguments(
+      builder, loc, fTy, resultBox, arrayBox, shiftBox, sourceFile, sourceLine);
+  builder.create(loc, cshiftFunc, args);
+}
+
+/// Generate call to Eoshift intrinsic runtime routine, forwarding dimBox.
+void fir::runtime::genEoshift(fir::FirOpBuilder &builder, mlir::Location loc,
+                              mlir::Value resultBox, mlir::Value arrayBox,
+                              mlir::Value shiftBox, mlir::Value boundBox,
+                              mlir::Value dimBox) {
+  auto eoshiftFunc =
+      fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = eoshiftFunc.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(6));
+  auto args = fir::runtime::createArguments(builder, loc, fTy, resultBox,
+                                            arrayBox, shiftBox, boundBox,
+                                            dimBox, sourceFile, sourceLine);
+  builder.create(loc, eoshiftFunc, args);
+}
+
+/// Generate call to the vector version of the Eoshift intrinsic (no dimBox).
+void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
+                                    mlir::Location loc, mlir::Value resultBox,
+                                    mlir::Value arrayBox, mlir::Value shiftBox,
+                                    mlir::Value boundBox) {
+  auto eoshiftFunc =
+      fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = eoshiftFunc.getType();
+
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, arrayBox,
+                                    shiftBox, boundBox, sourceFile, sourceLine);
+  builder.create(loc, eoshiftFunc, args);
+}
+
+/// Generate call to Matmul intrinsic runtime routine; resultBox comes first.
+void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
+                             mlir::Value resultBox, mlir::Value matrixABox,
+                             mlir::Value matrixBBox) {
+  auto func = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = func.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, matrixABox,
+                                    matrixBBox, sourceFile, sourceLine);
+  builder.create(loc, func, args);
+}
+
+/// Generate call to Pack intrinsic runtime routine.
+void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
+                           mlir::Value resultBox, mlir::Value arrayBox,
+                           mlir::Value maskBox, mlir::Value vectorBox) {
+  auto packFunc = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = packFunc.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, arrayBox,
+                                    maskBox, vectorBox, sourceFile, sourceLine);
+  builder.create(loc, packFunc, args);
+}
+
+/// Generate call to Reshape intrinsic runtime routine.
+void fir::runtime::genReshape(fir::FirOpBuilder &builder, mlir::Location loc,
+                              mlir::Value resultBox, mlir::Value sourceBox,
+                              mlir::Value shapeBox, mlir::Value padBox,
+                              mlir::Value orderBox) {
+  auto func = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = func.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(6));
+  auto args = fir::runtime::createArguments(builder, loc, fTy, resultBox,
+                                            sourceBox, shapeBox, padBox,
+                                            orderBox, sourceFile, sourceLine);
+  builder.create(loc, func, args);
+}
+
+/// Generate call to Spread intrinsic runtime routine.
+void fir::runtime::genSpread(fir::FirOpBuilder &builder, mlir::Location loc,
+                             mlir::Value resultBox, mlir::Value sourceBox,
+                             mlir::Value dim, mlir::Value ncopies) {
+  auto func = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = func.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, sourceBox,
+                                    dim, ncopies, sourceFile, sourceLine);
+  builder.create(loc, func, args);
+}
+
+/// Generate call to Transpose intrinsic runtime routine.
+void fir::runtime::genTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Value resultBox, mlir::Value sourceBox) {
+  auto func = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = func.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+  auto args = fir::runtime::createArguments(builder, loc, fTy, resultBox,
+                                            sourceBox, sourceFile, sourceLine);
+  builder.create(loc, func, args);
+}
+
+/// Generate call to Unpack intrinsic runtime routine.
+void fir::runtime::genUnpack(fir::FirOpBuilder &builder, mlir::Location loc,
+                             mlir::Value resultBox, mlir::Value vectorBox,
+                             mlir::Value maskBox, mlir::Value fieldBox) {
+  auto unpackFunc = fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = unpackFunc.getType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, vectorBox,
+                                    maskBox, fieldBox, sourceFile, sourceLine);
+  builder.create(loc, unpackFunc, args);
+}
diff --git a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
--- a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
+++ b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
@@ -7,92 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/Runtime/Reduction.h"
+#include "RuntimeCallTestBase.h"
 #include "gtest/gtest.h"
-#include "flang/Optimizer/Builder/FIRBuilder.h"
-#include "flang/Optimizer/Support/InitFIR.h"
-#include "flang/Optimizer/Support/KindMapping.h"
-
-struct ReductionTest : public testing::Test {
-public:
-  void SetUp() override {
-    mlir::OpBuilder builder(&context);
-    auto loc = builder.getUnknownLoc();
-
-    // Set up a Module with a
dummy function operation inside. - // Set the insertion point in the function entry block. - mlir::ModuleOp mod = builder.create(loc); - mlir::FuncOp func = mlir::FuncOp::create( - loc, "func1", builder.getFunctionType(llvm::None, llvm::None)); - auto *entryBlock = func.addEntryBlock(); - mod.push_back(mod); - builder.setInsertionPointToStart(entryBlock); - - fir::support::loadDialects(context); - kindMap = std::make_unique(&context); - firBuilder = std::make_unique(mod, *kindMap); - - i8Ty = firBuilder->getI8Type(); - i16Ty = firBuilder->getIntegerType(16); - i32Ty = firBuilder->getI32Type(); - i64Ty = firBuilder->getI64Type(); - i128Ty = firBuilder->getIntegerType(128); - - f32Ty = firBuilder->getF32Type(); - f64Ty = firBuilder->getF64Type(); - f80Ty = firBuilder->getF80Type(); - f128Ty = firBuilder->getF128Type(); - - c4Ty = fir::ComplexType::get(firBuilder->getContext(), 4); - c8Ty = fir::ComplexType::get(firBuilder->getContext(), 8); - c10Ty = fir::ComplexType::get(firBuilder->getContext(), 10); - c16Ty = fir::ComplexType::get(firBuilder->getContext(), 16); - } - - mlir::MLIRContext context; - std::unique_ptr kindMap; - std::unique_ptr firBuilder; - - // Commnly used type - mlir::Type i8Ty; - mlir::Type i16Ty; - mlir::Type i32Ty; - mlir::Type i64Ty; - mlir::Type i128Ty; - mlir::Type f32Ty; - mlir::Type f64Ty; - mlir::Type f80Ty; - mlir::Type f128Ty; - mlir::Type c4Ty; - mlir::Type c8Ty; - mlir::Type c10Ty; - mlir::Type c16Ty; -}; - -void checkCallOp( - mlir::Operation *op, llvm::StringRef fctName, unsigned nbArgs) { - EXPECT_TRUE(mlir::isa(*op)); - auto callOp = mlir::dyn_cast(*op); - EXPECT_TRUE(callOp.callee().hasValue()); - mlir::SymbolRefAttr callee = *callOp.callee(); - EXPECT_EQ(fctName, callee.getRootReference().getValue()); - // sourceFile and sourceLine are added arguments. 
- EXPECT_EQ(nbArgs + 2, callOp.args().size()); -} - -void checkCallOpFromResultBox( - mlir::Value result, llvm::StringRef fctName, unsigned nbArgs) { - EXPECT_TRUE(result.hasOneUse()); - for (auto &u : result.getUses()) { - if (mlir::isa(*u.getOwner())) { - checkCallOp(u.getOwner(), fctName, nbArgs); - } else { - auto convOp = mlir::dyn_cast(*u.getOwner()); - checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs); - } - } -} -TEST_F(ReductionTest, genAllTest) { +TEST_F(RuntimeCallTest, genAllTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -102,7 +20,7 @@ checkCallOp(all.getDefiningOp(), "_FortranAAll", 2); } -TEST_F(ReductionTest, genAllDescriptorTest) { +TEST_F(RuntimeCallTest, genAllDescriptorTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -113,7 +31,7 @@ checkCallOpFromResultBox(result, "_FortranAAllDim", 3); } -TEST_F(ReductionTest, genAnyTest) { +TEST_F(RuntimeCallTest, genAnyTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -123,7 +41,7 @@ checkCallOp(any.getDefiningOp(), "_FortranAAny", 2); } -TEST_F(ReductionTest, genAnyDescriptorTest) { +TEST_F(RuntimeCallTest, genAnyDescriptorTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -134,7 +52,7 @@ checkCallOpFromResultBox(result, "_FortranAAnyDim", 3); } -TEST_F(ReductionTest, genCountTest) { +TEST_F(RuntimeCallTest, genCountTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -144,7 +62,7 @@ checkCallOp(count.getDefiningOp(), "_FortranACount", 2); } -TEST_F(ReductionTest, genCountDimTest) { +TEST_F(RuntimeCallTest, genCountDimTest) { auto loc = firBuilder->getUnknownLoc(); mlir::Type seqTy = 
fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); @@ -168,7 +86,7 @@ checkCallOp(max.getDefiningOp(), fctName, 3); } -TEST_F(ReductionTest, genMaxValTest) { +TEST_F(RuntimeCallTest, genMaxValTest) { testGenMaxVal(*firBuilder, f32Ty, "_FortranAMaxvalReal4"); testGenMaxVal(*firBuilder, f64Ty, "_FortranAMaxvalReal8"); testGenMaxVal(*firBuilder, f80Ty, "_FortranAMaxvalReal10"); @@ -193,7 +111,7 @@ checkCallOp(min.getDefiningOp(), fctName, 3); } -TEST_F(ReductionTest, genMinValTest) { +TEST_F(RuntimeCallTest, genMinValTest) { testGenMinVal(*firBuilder, f32Ty, "_FortranAMinvalReal4"); testGenMinVal(*firBuilder, f64Ty, "_FortranAMinvalReal8"); testGenMinVal(*firBuilder, f80Ty, "_FortranAMinvalReal10"); @@ -222,7 +140,7 @@ checkCallOp(sum.getDefiningOp(), fctName, 3); } -TEST_F(ReductionTest, genSumTest) { +TEST_F(RuntimeCallTest, genSumTest) { testGenSum(*firBuilder, f32Ty, "_FortranASumReal4"); testGenSum(*firBuilder, f64Ty, "_FortranASumReal8"); testGenSum(*firBuilder, f80Ty, "_FortranASumReal10"); @@ -255,7 +173,7 @@ checkCallOp(prod.getDefiningOp(), fctName, 3); } -TEST_F(ReductionTest, genProduct) { +TEST_F(RuntimeCallTest, genProduct) { testGenProduct(*firBuilder, f32Ty, "_FortranAProductReal4"); testGenProduct(*firBuilder, f64Ty, "_FortranAProductReal8"); testGenProduct(*firBuilder, f80Ty, "_FortranAProductReal10"); @@ -287,7 +205,7 @@ checkCallOp(prod.getDefiningOp(), fctName, 2); } -TEST_F(ReductionTest, genDotProduct) { +TEST_F(RuntimeCallTest, genDotProduct) { testGenDotProduct(*firBuilder, f32Ty, "_FortranADotProductReal4"); testGenDotProduct(*firBuilder, f64Ty, "_FortranADotProductReal8"); testGenDotProduct(*firBuilder, f80Ty, "_FortranADotProductReal10"); @@ -321,11 +239,11 @@ checkCallOpFromResultBox(result, fctName, nbArgs); } -TEST_F(ReductionTest, genMaxlocTest) { +TEST_F(RuntimeCallTest, genMaxlocTest) { checkGenMxxloc(*firBuilder, fir::runtime::genMaxloc, "_FortranAMaxloc", 5); } -TEST_F(ReductionTest, genMinlocTest) { 
+TEST_F(RuntimeCallTest, genMinlocTest) { checkGenMxxloc(*firBuilder, fir::runtime::genMinloc, "_FortranAMinloc", 5); } @@ -348,12 +266,12 @@ checkCallOpFromResultBox(result, fctName, nbArgs); } -TEST_F(ReductionTest, genMaxlocDimTest) { +TEST_F(RuntimeCallTest, genMaxlocDimTest) { checkGenMxxlocDim( *firBuilder, fir::runtime::genMaxlocDim, "_FortranAMaxlocDim", 6); } -TEST_F(ReductionTest, genMinlocDimTest) { +TEST_F(RuntimeCallTest, genMinlocDimTest) { checkGenMxxlocDim( *firBuilder, fir::runtime::genMinlocDim, "_FortranAMinlocDim", 6); } @@ -374,12 +292,12 @@ checkCallOpFromResultBox(result, fctName, nbArgs); } -TEST_F(ReductionTest, genMaxvalCharTest) { +TEST_F(RuntimeCallTest, genMaxvalCharTest) { checkGenMxxvalChar( *firBuilder, fir::runtime::genMaxvalChar, "_FortranAMaxvalCharacter", 3); } -TEST_F(ReductionTest, genMinvalCharTest) { +TEST_F(RuntimeCallTest, genMinvalCharTest) { checkGenMxxvalChar( *firBuilder, fir::runtime::genMinvalChar, "_FortranAMinvalCharacter", 3); } @@ -401,21 +319,21 @@ checkCallOpFromResultBox(result, fctName, nbArgs); } -TEST_F(ReductionTest, genMaxvalDimTest) { +TEST_F(RuntimeCallTest, genMaxvalDimTest) { checkGen4argsDim( *firBuilder, fir::runtime::genMaxvalDim, "_FortranAMaxvalDim", 4); } -TEST_F(ReductionTest, genMinvalDimTest) { +TEST_F(RuntimeCallTest, genMinvalDimTest) { checkGen4argsDim( *firBuilder, fir::runtime::genMinvalDim, "_FortranAMinvalDim", 4); } -TEST_F(ReductionTest, genProductDimTest) { +TEST_F(RuntimeCallTest, genProductDimTest) { checkGen4argsDim( *firBuilder, fir::runtime::genProductDim, "_FortranAProductDim", 4); } -TEST_F(ReductionTest, genSumDimTest) { +TEST_F(RuntimeCallTest, genSumDimTest) { checkGen4argsDim(*firBuilder, fir::runtime::genSumDim, "_FortranASumDim", 4); } diff --git a/flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h b/flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h new file mode 100644 --- /dev/null +++ 
b/flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h
@@ -0,0 +1,109 @@
+//===- RuntimeCallTestBase.h -- Base for runtime call generation tests ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
+#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
+
+#include "gtest/gtest.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "flang/Optimizer/Support/KindMapping.h"
+
+/// Shared fixture for the runtime call generation tests: SetUp() creates the
+/// FIR builder and caches the MLIR types the tests use.
+struct RuntimeCallTest : public testing::Test {
+public:
+  void SetUp() override {
+    mlir::OpBuilder builder(&context);
+    auto loc = builder.getUnknownLoc();
+
+    // Set up a Module with a dummy function operation inside.
+    // Set the insertion point in the function entry block.
+    mlir::ModuleOp mod = builder.create(loc);
+    mlir::FuncOp func = mlir::FuncOp::create(
+        loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
+    auto *entryBlock = func.addEntryBlock();
+    // NOTE(review): this pushes `mod` into itself and never attaches `func`,
+    // contrary to the comment above. `mod.push_back(func)` looks intended, but
+    // firBuilder below is constructed from `mod` -- confirm before changing.
+    mod.push_back(mod);
+    builder.setInsertionPointToStart(entryBlock);
+
+    fir::support::loadDialects(context);
+    kindMap = std::make_unique(&context);
+    firBuilder = std::make_unique(mod, *kindMap);
+
+    i8Ty = firBuilder->getI8Type();
+    i16Ty = firBuilder->getIntegerType(16);
+    i32Ty = firBuilder->getI32Type();
+    i64Ty = firBuilder->getI64Type();
+    i128Ty = firBuilder->getIntegerType(128);
+
+    f32Ty = firBuilder->getF32Type();
+    f64Ty = firBuilder->getF64Type();
+    f80Ty = firBuilder->getF80Type();
+    f128Ty = firBuilder->getF128Type();
+
+    c4Ty = fir::ComplexType::get(firBuilder->getContext(), 4);
+    c8Ty = fir::ComplexType::get(firBuilder->getContext(), 8);
+    c10Ty = fir::ComplexType::get(firBuilder->getContext(), 10);
+    c16Ty = fir::ComplexType::get(firBuilder->getContext(), 16);
+  }
+
+  mlir::MLIRContext context;
+  std::unique_ptr kindMap;
+  std::unique_ptr firBuilder;
+
+  // Commonly used types
+  mlir::Type i8Ty;
+  mlir::Type i16Ty;
+  mlir::Type i32Ty;
+  mlir::Type i64Ty;
+  mlir::Type i128Ty;
+  mlir::Type f32Ty;
+  mlir::Type f64Ty;
+  mlir::Type f80Ty;
+  mlir::Type f128Ty;
+  mlir::Type c4Ty;
+  mlir::Type c8Ty;
+  mlir::Type c10Ty;
+  mlir::Type c16Ty;
+};
+
+/// Check that \p op is a call to \p fctName taking \p nbArgs arguments plus
+/// the trailing sourceFile and sourceLine arguments.
+static inline void checkCallOp(
+    mlir::Operation *op, llvm::StringRef fctName, unsigned nbArgs) {
+  // Fatal assertions: the statements below dereference the cast result.
+  ASSERT_TRUE(mlir::isa(*op));
+  auto callOp = mlir::dyn_cast(*op);
+  ASSERT_TRUE(callOp.callee().hasValue());
+  mlir::SymbolRefAttr callee = *callOp.callee();
+  EXPECT_EQ(fctName, callee.getRootReference().getValue());
+  // sourceFile and sourceLine are added arguments.
+  EXPECT_EQ(nbArgs + 2, callOp.args().size());
+}
+
+/// Follow the single use of \p result -- looking through any intermediate
+/// operation's result -- down to the call and check it with checkCallOp().
+static inline void checkCallOpFromResultBox(
+    mlir::Value result, llvm::StringRef fctName, unsigned nbArgs) {
+  EXPECT_TRUE(result.hasOneUse());
+  for (auto &u : result.getUses()) {
+    if (mlir::isa(*u.getOwner())) {
+      checkCallOp(u.getOwner(), fctName, nbArgs);
+    } else {
+      auto convOp = mlir::dyn_cast(*u.getOwner());
+      // Fail instead of dereferencing a null op when the cast does not match.
+      ASSERT_TRUE(static_cast<bool>(convOp));
+      checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs);
+    }
+  }
+}
+
+#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
diff --git a/flang/unittests/Optimizer/Builder/Runtime/TransformationalTest.cpp b/flang/unittests/Optimizer/Builder/Runtime/TransformationalTest.cpp
new file mode 100644
--- /dev/null
+++ b/flang/unittests/Optimizer/Builder/Runtime/TransformationalTest.cpp
@@ -0,0 +1,129 @@
+//===- TransformationalTest.cpp -- Transformational intrinsic generation --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Runtime/Transformational.h"
+#include "RuntimeCallTestBase.h"
+#include "gtest/gtest.h"
+
+TEST_F(RuntimeCallTest, genCshiftTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value array = firBuilder->create(loc, seqTy);
+  mlir::Value shift = firBuilder->create(loc, seqTy);
+  mlir::Value dim = firBuilder->create(loc, seqTy);
+  fir::runtime::genCshift(*firBuilder, loc, result, array, shift, dim);
+  checkCallOpFromResultBox(result, "_FortranACshift", 4);
+}
+
+TEST_F(RuntimeCallTest, genCshiftVectorTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value array = firBuilder->create(loc, seqTy);
+  mlir::Value shift = firBuilder->create(loc, seqTy);
+  fir::runtime::genCshiftVector(*firBuilder, loc, result, array, shift);
+  checkCallOpFromResultBox(result, "_FortranACshiftVector", 3);
+}
+
+TEST_F(RuntimeCallTest, genEoshiftTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value array = firBuilder->create(loc, seqTy);
+  mlir::Value shift = firBuilder->create(loc, seqTy);
+  mlir::Value bound = firBuilder->create(loc, seqTy);
+  mlir::Value dim = firBuilder->create(loc, seqTy);
+  fir::runtime::genEoshift(*firBuilder, loc, result, array, shift, bound, dim);
+  checkCallOpFromResultBox(result, "_FortranAEoshift", 5);
+}
+
+TEST_F(RuntimeCallTest, genEoshiftVectorTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value array = firBuilder->create(loc, seqTy);
+  mlir::Value shift = firBuilder->create(loc, seqTy);
+  mlir::Value bound = firBuilder->create(loc, seqTy);
+  fir::runtime::genEoshiftVector(*firBuilder, loc, result, array, shift, bound);
+  checkCallOpFromResultBox(result, "_FortranAEoshiftVector", 4);
+}
+
+TEST_F(RuntimeCallTest, genMatmulTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value matrixA = firBuilder->create(loc, seqTy);
+  mlir::Value matrixB = firBuilder->create(loc, seqTy);
+  fir::runtime::genMatmul(*firBuilder, loc, result, matrixA, matrixB);
+  checkCallOpFromResultBox(result, "_FortranAMatmul", 3);
+}
+
+TEST_F(RuntimeCallTest, genPackTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value array = firBuilder->create(loc, seqTy);
+  mlir::Value mask = firBuilder->create(loc, seqTy);
+  mlir::Value vector = firBuilder->create(loc, seqTy);
+  fir::runtime::genPack(*firBuilder, loc, result, array, mask, vector);
+  checkCallOpFromResultBox(result, "_FortranAPack", 4);
+}
+
+TEST_F(RuntimeCallTest, genReshapeTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value source = firBuilder->create(loc, seqTy);
+  mlir::Value shape = firBuilder->create(loc, seqTy);
+  mlir::Value pad = firBuilder->create(loc, seqTy);
+  mlir::Value order = firBuilder->create(loc, seqTy);
+  fir::runtime::genReshape(*firBuilder, loc, result, source, shape, pad, order);
+  checkCallOpFromResultBox(result, "_FortranAReshape", 5);
+}
+
+TEST_F(RuntimeCallTest, genSpreadTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value source = firBuilder->create(loc, seqTy);
+  mlir::Value dim = firBuilder->create(loc, seqTy);
+  mlir::Value ncopies = firBuilder->create(loc, seqTy);
+  fir::runtime::genSpread(*firBuilder, loc, result, source, dim, ncopies);
+  checkCallOpFromResultBox(result, "_FortranASpread", 4);
+}
+
+TEST_F(RuntimeCallTest, genTransposeTest) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value source = firBuilder->create(loc, seqTy);
+  fir::runtime::genTranspose(*firBuilder, loc, result, source);
+  checkCallOpFromResultBox(result, "_FortranATranspose", 2);
+}
+
+TEST_F(RuntimeCallTest, genUnpack) {
+  auto loc = firBuilder->getUnknownLoc();
+  mlir::Type seqTy =
+      fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
+  mlir::Value result = firBuilder->create(loc, seqTy);
+  mlir::Value vector = firBuilder->create(loc, seqTy);
+  mlir::Value mask = firBuilder->create(loc, seqTy);
+  mlir::Value field = firBuilder->create(loc, seqTy);
+  fir::runtime::genUnpack(*firBuilder, loc, result, vector, mask, field);
+  checkCallOpFromResultBox(result, "_FortranAUnpack", 4);
+}
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -14,6 +14,7 @@
   Builder/DoLoopHelperTest.cpp
   Builder/FIRBuilderTest.cpp
   Builder/Runtime/ReductionTest.cpp
+  Builder/Runtime/TransformationalTest.cpp
   FIRContextTest.cpp
   InternalNamesTest.cpp
   KindMappingTest.cpp