diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2764,19 +2764,24 @@
 
   OpBuilder::InsertionGuard guard(odsBuilder);
 
-  SmallVector<Location, 4> blockArgLocs;
+  // Build before region.
+  SmallVector<Location, 4> beforeArgLocs;
+  beforeArgLocs.reserve(operands.size());
   for (Value operand : operands) {
-    blockArgLocs.push_back(operand.getLoc());
+    beforeArgLocs.push_back(operand.getLoc());
   }
 
   Region *beforeRegion = odsState.addRegion();
-  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
-                                              resultTypes, blockArgLocs);
+  Block *beforeBlock = odsBuilder.createBlock(
+      beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
   beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
 
+  // Build after region.
+  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
+
   Region *afterRegion = odsState.addRegion();
   Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
-                                             resultTypes, blockArgLocs);
+                                             resultTypes, afterArgLocs);
   afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
diff --git a/mlir/test/Dialect/SCF/while-op-builder.mlir b/mlir/test/Dialect/SCF/while-op-builder.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/while-op-builder.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-scf-while-op-builder | FileCheck %s
+
+// CHECK-LABEL: @testMatchingTypes
+func.func @testMatchingTypes(%arg0 : i32) {
+  %0 = scf.while (%arg1 = %arg0) : (i32) -> (i32) {
+    %c10 = arith.constant 10 : i32
+    %1 = arith.cmpi slt, %arg1, %c10 : i32
+    scf.condition(%1) %arg1 : i32
+  } do {
+  ^bb0(%arg1: i32):
+    scf.yield %arg1 : i32
+  }
+  // Expect the same loop twice (the dummy added by the test pass and the
+  // original one).
+  // CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> i32 {
+  // CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> i32 {
+  return
+}
+
+// CHECK-LABEL: @testNonMatchingTypes
+func.func @testNonMatchingTypes(%arg0 : i32) {
+  %c1 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : i32
+  %0:2 = scf.while (%arg1 = %arg0) : (i32) -> (i32, i32) {
+    %1 = arith.cmpi slt, %arg1, %c10 : i32
+    scf.condition(%1) %arg1, %c1 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    scf.yield %1 : i32
+  }
+  // Expect the same loop twice (the dummy added by the test pass and the
+  // original one).
+  // CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> (i32, i32) {
+  // CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> (i32, i32) {
+  return
+}
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
+  TestWhileOpBuilder.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
@@ -0,0 +1,82 @@
+//===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test some builder functions of WhileOp.
+// It tests the regression explained in https://reviews.llvm.org/D142952,
+// where a WhileOp::build overload crashed when fed with operands whose types
+// differ from the result types.
+//
+// To test the build function, the pass inserts, right before each WhileOp
+// found in the body of a FuncOp, an additional WhileOp with the same operands
+// and result types (but dummy computations) built with the builder in question.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::scf;
+
+namespace {
+struct TestSCFWhileOpBuilderPass
+    : public PassWrapper<TestSCFWhileOpBuilderPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)
+
+  StringRef getArgument() const final { return "test-scf-while-op-builder"; }
+  StringRef getDescription() const final {
+    return "test build functions of scf.while";
+  }
+  explicit TestSCFWhileOpBuilderPass() = default;
+  TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    func.walk([&](WhileOp whileOp) {
+      Location loc = whileOp->getLoc();
+      ImplicitLocOpBuilder builder(loc, whileOp);
+
+      // Create a dummy WhileOp with the same operands and result types.
+      TypeRange resultTypes = whileOp->getResultTypes();
+      ValueRange operands = whileOp->getOperands();
+      builder.create<WhileOp>(
+          loc, resultTypes, operands, /*beforeBuilder=*/
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            // Cast the before args (operand types) to the result types.
+            ImplicitLocOpBuilder builder(loc, b);
+            auto castOp =
+                builder.create<UnrealizedConversionCastOp>(resultTypes, args);
+            auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
+            builder.create<ConditionOp>(cmp, castOp->getResults());
+          },
+          /*afterBuilder=*/
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            // Cast the after args (result types) back to the operand types.
+            ImplicitLocOpBuilder builder(loc, b);
+            auto castOp = builder.create<UnrealizedConversionCastOp>(
+                operands.getTypes(), args);
+            builder.create<YieldOp>(castOp->getResults());
+          });
+    });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFWhileOpBuilderPass() {
+  PassRegistration<TestSCFWhileOpBuilderPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestSCFWhileOpBuilderPass();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorCopyInsertionPass();
@@ -220,6 +221,7 @@
   mlir::test::registerTestPDLLPasses();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestSCFWhileOpBuilderPass();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();