diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -403,6 +403,12 @@
     YieldOp thenYield();
     Block* elseBlock();
     YieldOp elseYield();
+
+    /// Fills `countPerRegion` with one entry per region (`then`, `else`). If
+    /// the condition is a constant, the executed region gets 1 and the other
+    /// 0; otherwise both get `kUnknownNumRegionInvocations`.
+    void getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<int64_t> &countPerRegion);
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -474,7 +474,7 @@
 
   // Loop bounds are not known statically.
   if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
-    countPerRegion[0] = -1;
+    countPerRegion[0] = kUnknownNumRegionInvocations;
     return;
   }
 
@@ -1181,6 +1181,23 @@
   regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
 }
 
+/// Fills `countPerRegion` with one entry per region (`then`, then `else`). If
+/// the condition is a constant, the executed region gets 1 and the other 0;
+/// otherwise both get `kUnknownNumRegionInvocations`.
+void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<int64_t> &countPerRegion) {
+  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    // A constant condition selects exactly one region: `then` runs once when
+    // it is true, `else` runs once when it is false.
+    bool cond = condAttr.getValue().isOneValue();
+    countPerRegion.assign(1, cond ? 1 : 0);
+    countPerRegion.push_back(cond ? 0 : 1);
+  } else {
+    // Non-constant condition: unknown invocations for both regions.
+    countPerRegion.assign(2, kUnknownNumRegionInvocations);
+  }
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@
   MLIRDialect)
 
 add_subdirectory(Quant)
+add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
 add_subdirectory(Utils)
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_unittest(MLIRSCFTests
+  SCFOps.cpp
+  )
+target_link_libraries(MLIRSCFTests
+  PRIVATE
+  MLIRIR
+  MLIRParser
+  MLIRSCF
+  MLIRStandard
+  )
diff --git a/mlir/unittests/Dialect/SCF/SCFOps.cpp b/mlir/unittests/Dialect/SCF/SCFOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/SCFOps.cpp
@@ -0,0 +1,67 @@
+//===- SCFOps.cpp - SCF Op Unit Tests -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Parser.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+class SCFOpsTest : public testing::Test {
+public:
+  SCFOpsTest() {
+    context.getOrLoadDialect<scf::SCFDialect>();
+    context.getOrLoadDialect<StandardOpsDialect>();
+  }
+protected:
+  MLIRContext context;
+};
+
+TEST_F(SCFOpsTest, IfOpNumRegionInvocations) {
+  const char *const code = R"(
+func @test(%cond : i1) -> () {
+  scf.if %cond {
+    scf.yield
+  } else {
+    scf.yield
+  }
+  return
+}
+)";
+  Builder builder(&context);
+  auto module = parseSourceString(code, &context);
+  ASSERT_TRUE(module);
+  scf::IfOp op;
+  module->walk([&](scf::IfOp ifOp) { op = ifOp; });
+  ASSERT_TRUE(op);
+
+  // A null attribute models a non-constant condition: both counts unknown.
+  SmallVector<int64_t, 2> countPerRegion;
+  op.getNumRegionInvocations({Attribute()}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  ASSERT_EQ(countPerRegion[0], kUnknownNumRegionInvocations);
+  ASSERT_EQ(countPerRegion[1], kUnknownNumRegionInvocations);
+
+  // A constant condition selects exactly one region, which runs once.
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), true)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  ASSERT_EQ(countPerRegion[0], 1);
+  ASSERT_EQ(countPerRegion[1], 0);
+
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), false)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  ASSERT_EQ(countPerRegion[0], 0);
+  ASSERT_EQ(countPerRegion[1], 1);
+}
+} // end anonymous namespace