diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -46,8 +46,8 @@ Value bbArg = warpOpBody->getArgument(it.index()); rewriter.setInsertionPoint(ifOp); - Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, - bbArg.getType()); + Value buffer = + options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType()); // Store arg vector into buffer. rewriter.setInsertionPoint(ifOp); @@ -68,7 +68,7 @@ // Insert sync after all the stores and before all the loads. if (!warpOp.getArgs().empty()) { rewriter.setInsertionPoint(ifOp); - options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); + options.warpSyncronizationFn(loc, rewriter, warpOp); } // Move body of warpOp to ifOp. @@ -82,8 +82,8 @@ Value val = it.value(); Type resultType = warpOp->getResultTypes()[it.index()]; rewriter.setInsertionPoint(ifOp); - Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, - val.getType()); + Value buffer = + options.warpAllocationFn(loc, rewriter, warpOp, val.getType()); // Store yielded value into buffer. rewriter.setInsertionPoint(yieldOp); @@ -121,7 +121,7 @@ // Insert sync after all the stores and before all the loads. if (!yieldOp.operands().empty()) { rewriter.setInsertionPointAfter(ifOp); - options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); + options.warpSyncronizationFn(loc, rewriter, warpOp); } // Delete terminator and add empty scf.yield. @@ -148,7 +148,12 @@ Region &opBody = warpOp.getBodyRegion(); Region &newOpBody = newWarpOp.getBodyRegion(); + Block &newOpFirstBlock = newOpBody.front(); rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin()); + rewriter.eraseBlock(&newOpFirstBlock); + assert(newWarpOp.getWarpRegion().hasOneBlock() && + "expected WarpOp with single block"); + auto yield = cast(newOpBody.getBlocks().begin()->getTerminator());