diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2764,19 +2764,24 @@ OpBuilder::InsertionGuard guard(odsBuilder); - SmallVector blockArgLocs; + // Build before region: block arguments mirror the operands' types and locs. + SmallVector beforeArgLocs; + beforeArgLocs.reserve(operands.size()); for (Value operand : operands) { - blockArgLocs.push_back(operand.getLoc()); + beforeArgLocs.push_back(operand.getLoc()); } Region *beforeRegion = odsState.addRegion(); - Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{}, - resultTypes, blockArgLocs); + Block *beforeBlock = odsBuilder.createBlock( + beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs); beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments()); + // Build after region: one block argument per result, all at the op's loc. + SmallVector afterArgLocs(resultTypes.size(), odsState.location); + Region *afterRegion = odsState.addRegion(); Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{}, - resultTypes, blockArgLocs); + resultTypes, afterArgLocs); afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); } diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(LLVMIR) add_subdirectory(MemRef) +add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Transform) diff --git a/mlir/unittests/Dialect/SCF/BuildersTest.cpp b/mlir/unittests/Dialect/SCF/BuildersTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/BuildersTest.cpp @@ -0,0 +1,74 @@ +//===- BuildersTest.cpp - unit tests for scf op builders ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::arith; + +// Ensure that WhileOp::build with builder functions for its regions compiles, +// hands each builder the right number of block arguments, and builds valid IR. +TEST(SCFBuildersTest, whileOpBuilderWithRegionBuilderArgs) { + MLIRContext ctx; + ctx.getOrLoadDialect<ArithDialect>(); + ctx.getOrLoadDialect<SCFDialect>(); + + OpBuilder b(&ctx); + Location loc = UnknownLoc::get(&ctx); + Type indexType = b.getIndexType(); + + // Build the following IR: + // + // module { + // %c0 = arith.constant 0 : index + // %c1 = arith.constant 1 : index + // %c10 = arith.constant 10 : index + // %0:2 = scf.while (%arg0 = %c0) : (index) -> (index, index) { + // %1 = arith.cmpi slt, %arg0, %c10 : index + // scf.condition(%1) %arg0, %c1 : index, index + // } do { + // ^bb0(%arg0: index, %arg1: index): + // %1 = arith.addi %arg0, %arg1 : index + // scf.yield %1 : index + // } + // } + + OwningOpRef<ModuleOp> module = b.create<ModuleOp>(loc); + b.setInsertionPointToStart(&module->getBodyRegion().front()); + + Value zero = b.create<ConstantIndexOp>(loc, 0); + Value one = b.create<ConstantIndexOp>(loc, 1); + Value ten = b.create<ConstantIndexOp>(loc, 10); + + // Make the argument types of the two regions different to catch mix-ups. + SmallVector<Type> operandTypes = {indexType}; + SmallVector<Type> resultTypes = {indexType, indexType}; + + // Make sure the builder runs without crashing.
+ auto whileOp = b.create<WhileOp>( + loc, resultTypes, /*operands=*/zero, /*beforeBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ASSERT_EQ(args.size(), operandTypes.size()); + Value cmp = + builder.create<CmpIOp>(loc, CmpIPredicate::slt, args[0], ten); + builder.create<ConditionOp>(loc, cmp, ValueRange{args[0], one}); + }, + /*afterBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ASSERT_EQ(args.size(), resultTypes.size()); + Value updatedIndex = builder.create<AddIOp>(loc, args[0], args[1]); + builder.create<YieldOp>(loc, updatedIndex); + }); + + // Make sure the op we built verifies. + EXPECT_TRUE(whileOp.verify().succeeded()); +} diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRSCFTests + BuildersTest.cpp +) +target_link_libraries(MLIRSCFTests + PRIVATE + MLIRSCFDialect + )