diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -667,21 +667,15 @@ let skipDefaultBuilders = 1; let builders = [ + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond)>, OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>, OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond, "bool":$withElseRegion)>, - // TODO: Remove builder when it is no longer used to create invalid `if` ops - // (with a type mispatch between the op and it's inner `yield` op). - OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond, - CArg<"function_ref", - "buildTerminatedBody">:$thenBuilder, - CArg<"function_ref", - "nullptr">:$elseBuilder)>, OpBuilder<(ins "Value":$cond, CArg<"function_ref", "buildTerminatedBody">:$thenBuilder, CArg<"function_ref", - "nullptr">:$elseBuilder)> + "nullptr">:$elseBuilder)>, ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -92,7 +92,7 @@ Type indexTy = lb.getIndexType(); broadcastedDim = lb.create( - TypeRange{indexTy}, outOfBounds, + outOfBounds, [&](OpBuilder &b, Location loc) { b.create(loc, broadcastedDim); }, @@ -293,7 +293,7 @@ loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = b.create( - loc, TypeRange{i1Ty}, outOfBounds, + loc, outOfBounds, [&](OpBuilder &b, Location loc) { // Non existent dimensions are always broadcastable b.create(loc, broadcastable); @@ -522,7 +522,7 @@ Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, firstRank, rank); auto same = rewriter.create( - loc, i1Ty, eqRank, + loc, eqRank, [&](OpBuilder &b, Location loc) { Value one = b.create(loc, 1); Value init = diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -192,7 +192,7 @@ // If the condition is non-empty, generate an SCF::IfOp. if (cond) { auto check = lb.create( - resultTypes, cond, + cond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -645,7 +645,7 @@ }; // Dispatch either single block compute function, or launch async dispatch. - b.create(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch); + b.create(isSingleBlock, syncDispatch, asyncDispatch); } // Dispatch parallel compute functions by submitting all async compute tasks @@ -910,8 +910,8 @@ Value useBlockAlignedComputeFn = b.create( arith::CmpIPredicate::sge, blockSize, numIters); - b.create(TypeRange(), useBlockAlignedComputeFn, - dispatchBlockAligned, dispatchDefault); + b.create(useBlockAlignedComputeFn, dispatchBlockAligned, + dispatchDefault); b.create(); } else { dispatchDefault(b, loc); @@ -919,7 +919,7 @@ }; // Replace the `scf.parallel` operation with the parallel compute function. - b.create(TypeRange(), isZeroIterations, noOp, dispatch); + b.create(isZeroIterations, noOp, dispatch); // Parallel operation was replaced with a block iteration loop. rewriter.eraseOp(op); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1485,44 +1485,41 @@ return success(); } +void IfOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, Value cond) { + result.addTypes(resultTypes); + result.addOperands(cond); + + // Build regions. + OpBuilder::InsertionGuard guard(builder); + result.addRegion(); + result.addRegion(); +} + void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, bool withElseRegion) { - build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion); + build(builder, result, TypeRange{}, cond, withElseRegion); } void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond, bool withElseRegion) { - auto addTerminator = [&](OpBuilder &nested, Location loc) { - if (resultTypes.empty()) - IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested, - loc); - }; - - build(builder, result, resultTypes, cond, addTerminator, - withElseRegion ? addTerminator - : function_ref()); -} - -void IfOp::build(OpBuilder &builder, OperationState &result, - TypeRange resultTypes, Value cond, - function_ref thenBuilder, - function_ref elseBuilder) { - assert(thenBuilder && "the builder callback for 'then' must be present"); - result.addOperands(cond); result.addTypes(resultTypes); + result.addOperands(cond); // Build then region. OpBuilder::InsertionGuard guard(builder); Region *thenRegion = result.addRegion(); builder.createBlock(thenRegion); - thenBuilder(builder, result.location); + if (resultTypes.empty()) + IfOp::ensureTerminator(*thenRegion, builder, result.location); // Build else region. Region *elseRegion = result.addRegion(); - if (!elseBuilder) - return; - builder.createBlock(elseRegion); - elseBuilder(builder, result.location); + if (withElseRegion) { + builder.createBlock(elseRegion); + if (resultTypes.empty()) + IfOp::ensureTerminator(*elseRegion, builder, result.location); + } } void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, @@ -1730,9 +1727,10 @@ [](OpResult result) { return result.getType(); }); // Create a replacement operation with empty then and else regions. - auto emptyBuilder = [](OpBuilder &, Location) {}; - auto newOp = rewriter.create(op.getLoc(), newTypes, op.getCondition(), - emptyBuilder, emptyBuilder); + auto newOp = + rewriter.create(op.getLoc(), newTypes, op.getCondition()); + rewriter.createBlock(&newOp.getThenRegion()); + rewriter.createBlock(&newOp.getElseRegion()); // Move the bodies and replace the terminators (note there is a then and // an else region since the operation returns results). @@ -1796,7 +1794,8 @@ if (nonHoistable.size() == op->getNumResults()) return failure(); - IfOp replacement = rewriter.create(op.getLoc(), nonHoistable, cond); + IfOp replacement = rewriter.create(op.getLoc(), nonHoistable, cond, + /*withElseRegion=*/false); if (replacement.thenBlock()) rewriter.eraseBlock(replacement.thenBlock()); replacement.getThenRegion().takeBody(op.getThenRegion()); @@ -2249,6 +2248,7 @@ Value newCondition = rewriter.create( loc, op.getCondition(), nestedIf.getCondition()); auto newIf = rewriter.create(loc, op.getResultTypes(), newCondition); + Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); SmallVector results; llvm::append_range(results, newIf.getResults()); @@ -2258,11 +2258,6 @@ results[idx] = rewriter.create( op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]); - Block *newIfBlock = newIf.thenBlock(); - if (newIfBlock) - rewriter.eraseOp(newIfBlock->getTerminator()); - else - newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); rewriter.setInsertionPointToEnd(newIf.thenBlock()); rewriter.replaceOpWithNewOp(newIf.thenYield(), thenYield); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -632,7 +632,7 @@ // creating SliceOps with result dimensions of size 0 at runtime. if (generateZeroSliceGuard && dynHasZeroLenCond) { auto result = b.create( - loc, resultType, dynHasZeroLenCond, + loc, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { b.create(loc, createGenerateOp()->getResult(0)); diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp @@ -81,8 +81,8 @@ Operation *newOp = builder.clone(*padOp); builder.create(loc, newOp->getResults()); }; - rewriter.replaceOpWithNewOp(padOp, padOp.getType(), ifCond, - thenBuilder, elseBuilder); + rewriter.replaceOpWithNewOp(padOp, ifCond, thenBuilder, + elseBuilder); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1126,7 +1126,7 @@ Value newResult = rewriter .create( - loc, distrType, isInsertingLane, + loc, isInsertingLane, /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { Value newInsert = builder.create( @@ -1257,7 +1257,7 @@ builder.create(loc, distributedDest); }; newResult = rewriter - .create(loc, distrDestType, isInsertingLane, + .create(loc, isInsertingLane, /*thenBuilder=*/insertingBuilder, /*elseBuilder=*/nonInsertingBuilder) .getResult(0); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -252,7 +252,7 @@ Value zero = b.create(loc, 0); Value memref = xferOp.getSource(); return b.create( - loc, returnTypes, inBoundsCond, + loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType()) @@ -307,7 +307,7 @@ Value zero = b.create(loc, 0); Value memref = xferOp.getSource(); return b.create( - loc, returnTypes, inBoundsCond, + loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType()) @@ -358,7 +358,7 @@ Value memref = xferOp.getSource(); return b .create( - loc, returnTypes, inBoundsCond, + loc, inBoundsCond, [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType())