diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/SideEffectUtils.h" - +#include "llvm/ADT/SetVector.h" #include using namespace mlir; @@ -165,19 +165,34 @@ } /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs. +/// `indices` returns the result index of each value in `newYieldedValues`. static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns( RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, - ValueRange newYieldedValues, TypeRange newReturnTypes) { + ValueRange newYieldedValues, TypeRange newReturnTypes, + llvm::SmallVector &indices) { SmallVector types(warpOp.getResultTypes().begin(), warpOp.getResultTypes().end()); - types.append(newReturnTypes.begin(), newReturnTypes.end()); auto yield = cast( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - SmallVector yieldValues(yield.getOperands().begin(), - yield.getOperands().end()); - yieldValues.append(newYieldedValues.begin(), newYieldedValues.end()); + llvm::SmallSetVector yieldValues(yield.getOperands().begin(), + yield.getOperands().end()); + for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) { + if (yieldValues.insert(std::get<0>(newRet))) { + types.push_back(std::get<1>(newRet)); + indices.push_back(yieldValues.size() - 1); + } else { + // If the value already exits the region, don't create a new output. 
+ for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) { + if (yieldOperand.value() == std::get<0>(newRet)) { + indices.push_back(yieldOperand.index()); + break; + } + } + } + } + yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end()); WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues, types); + rewriter, warpOp, yieldValues.getArrayRef(), types); rewriter.replaceOp(warpOp, newWarpOp.getResults().take_front(warpOp.getNumResults())); return newWarpOp; @@ -273,14 +288,15 @@ assert(writeOp->getParentOp() == warpOp && "write must be nested immediately under warp"); OpBuilder::InsertionGuard g(rewriter); + SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, ValueRange{{writeOp.getVector()}}, - TypeRange{targetType}); + TypeRange{targetType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); auto newWriteOp = cast(rewriter.clone(*writeOp.getOperation())); rewriter.eraseOp(writeOp); - newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back()); + newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); return newWriteOp; } @@ -387,8 +403,9 @@ SmallVector yieldValues = {writeOp.getVector()}; SmallVector retTypes = {vecType}; + SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldValues, retTypes); + rewriter, warpOp, yieldValues, retTypes, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); // Create a second warp op that contains only writeOp. 
@@ -398,8 +415,7 @@ rewriter.setInsertionPointToStart(&body); auto newWriteOp = cast(rewriter.clone(*writeOp.getOperation())); - newWriteOp.getVectorMutable().assign( - newWarpOp.getResult(newWarpOp.getNumResults() - 1)); + newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); rewriter.eraseOp(writeOp); rewriter.create(newWarpOp.getLoc()); return success(); @@ -489,14 +505,14 @@ retTypes.push_back(targetType); yieldValues.push_back(operand.get()); } - unsigned numResults = warpOp.getNumResults(); + SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldValues, retTypes); + rewriter, warpOp, yieldValues, retTypes, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); SmallVector newOperands(elementWise->getOperands().begin(), elementWise->getOperands().end()); for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) { - newOperands[i] = newWarpOp.getResult(i + numResults); + newOperands[i] = newWarpOp.getResult(newRetIndices[i]); } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); @@ -653,12 +669,13 @@ Location loc = broadcastOp.getLoc(); auto destVecType = warpOp->getResultTypes()[operandNumber].cast(); + SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastOp.getSource()}, - {broadcastOp.getSource().getType()}); + {broadcastOp.getSource().getType()}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value broadcasted = rewriter.create( - loc, destVecType, newWarpOp->getResults().back()); + loc, destVecType, newWarpOp->getResult(newRetIndices[0])); newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted); return success(); } @@ -814,12 +831,12 @@ SmallVector yieldValues = {reductionOp.getVector()}; SmallVector retTypes = { VectorType::get({numElements}, reductionOp.getType())}; - unsigned numResults = warpOp.getNumResults(); + 
SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldValues, retTypes); + rewriter, warpOp, yieldValues, retTypes, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value laneValVec = newWarpOp.getResult(numResults); + Value laneValVec = newWarpOp.getResult(newRetIndices[0]); // First reduce on a single thread. Value perLaneReduction = rewriter.create( reductionOp.getLoc(), reductionOp.getKind(), laneValVec); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -545,3 +545,20 @@ } return %r : f32 } + +// ----- + +// CHECK-PROP-LABEL: func @warp_duplicate_yield( +func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>) { + // CHECK-PROP: %{{.*}}:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>) + %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, vector<1xf32>) { + %2 = "some_def"() : () -> (vector<32xf32>) + %3 = "some_def"() : () -> (vector<32xf32>) + %4 = arith.addf %2, %3 : vector<32xf32> + %5 = arith.addf %2, %2 : vector<32xf32> +// CHECK-PROP-NOT: arith.addf +// CHECK-PROP: vector.yield %{{.*}}, %{{.*}} : vector<32xf32>, vector<32xf32> + vector.yield %4, %5 : vector<32xf32>, vector<32xf32> + } + return %r#0, %r#1 : vector<1xf32>, vector<1xf32> +}