diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -610,12 +611,11 @@
 // IfOp
 //===----------------------------------------------------------------------===//
 
-def IfOp : SCF_Op<"if",
-      [DeclareOpInterfaceMethods<RegionBranchOpInterface,
-                                 ["getNumRegionInvocations",
-                                  "getRegionInvocationBounds"]>,
-       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
-       NoRegionArguments]> {
+def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+    "getNumRegionInvocations", "getRegionInvocationBounds"]>,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
+    NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
     The `scf.if` operation represents an if-then-else construct for
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1467,6 +1467,36 @@
   return false;
 }
 
+/// Infers the result types of an `scf.if` from the operands of the
+/// `scf.yield` that ends the "then" region; the "else" region is not
+/// consulted. Returns failure, rather than asserting or dereferencing a null
+/// op, when there are no regions, the "then" region has no block, that block
+/// is empty, or its last operation is not an `scf.yield`. `IfOp::build`
+/// relies on this to build an op without results in those cases.
+LogicalResult
+IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+                       ValueRange operands, DictionaryAttr attrs,
+                       RegionRange regions,
+                       SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (regions.empty())
+    return failure();
+  Region *thenRegion = regions.front();
+  if (thenRegion->empty())
+    return failure();
+  Block &thenBlock = thenRegion->front();
+  if (thenBlock.empty())
+    return failure();
+  // Look at back() instead of getTerminator(), which requires the block to
+  // already end in a terminator; the `then` builder may not have added one.
+  auto yieldOp = llvm::dyn_cast<YieldOp>(&thenBlock.back());
+  if (!yieldOp)
+    return failure();
+  TypeRange types = yieldOp.getOperandTypes();
+  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
+                             types.end());
+  return success();
+}
+
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  bool withElseRegion) {
   build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
@@ -1516,19 +1546,26 @@
   // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
-  Block *thenBlock = builder.createBlock(thenRegion);
+  builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
-  // Infer types if there are any.
-  if (auto yieldOp = llvm::dyn_cast<YieldOp>(thenBlock->getTerminator()))
-    result.addTypes(yieldOp.getOperandTypes());
-
   // Build else region.
   Region *elseRegion = result.addRegion();
-  if (!elseBuilder)
-    return;
-  builder.createBlock(elseRegion);
-  elseBuilder(builder, result.location);
+  if (elseBuilder) {
+    builder.createBlock(elseRegion);
+    elseBuilder(builder, result.location);
+  }
+
+  // Infer result types from the `scf.yield` in the then region. Inference
+  // fails if the then builder did not end the block with an `scf.yield`; the
+  // op is then built without results.
+  SmallVector<Type> inferredReturnTypes;
+  MLIRContext *ctx = builder.getContext();
+  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
+  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
+                                 result.regions, inferredReturnTypes))) {
+    result.addTypes(inferredReturnTypes);
+  }
 }
 
 LogicalResult IfOp::verify() {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1900,6 +1900,7 @@
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":InferTypeOpInterfaceTdFiles",
         ":LoopLikeInterfaceTdFiles",
         ":ParallelCombiningOpInterfaceTdFiles",
         ":SideEffectInterfacesTdFiles",
@@ -2929,6 +2930,7 @@
         ":ControlFlowInterfaces",
         ":FuncDialect",
         ":IR",
+        ":InferTypeOpInterface",
         ":LoopLikeInterface",
         ":MemRefDialect",
         ":ParallelCombiningOpInterface",