diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -13,6 +13,8 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include
@@ -753,25 +755,47 @@
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector resultTypes;
-    SmallVector yieldValues;
+    SmallVector newResultTypes;
+    SmallVector newYieldValues;
+    DenseMap dedupYieldOperandPositionMap;
+    DenseMap dedupResultPositionMap;
     auto yield = cast(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
+    // Some values may be yielded multiple times and correspond to multiple
+    // results. Deduplicating occurs by taking each result with its matching
+    // yielded value, and:
+    //   1. recording the unique first position at which the value is yielded.
+    //   2. recording for the result, the first position at which the dedup'ed
+    //      value is yielded.
+    //   3. skipping from the new result types / new yielded values any result
+    //      that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
+      // Skip dead results before recording anything: a dead result is never
+      // materialized, so it must not claim a position for its yielded value.
       if (result.use_empty())
         continue;
-      resultTypes.push_back(result.getType());
-      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
+      Value yieldOperand = yield.getOperand(result.getResultNumber());
+      auto it = dedupYieldOperandPositionMap.insert(
+          std::make_pair(yieldOperand, newResultTypes.size()));
+      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
+      if (!it.second)
+        continue;
+      newResultTypes.push_back(result.getType());
+      newYieldValues.push_back(yieldOperand);
     }
-    if (yield.getNumOperands() == yieldValues.size())
+    // No modification, exit early.
+    if (yield.getNumOperands() == newYieldValues.size())
       return failure();
+    // Move the body of the old warpOp to a new warpOp.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, yieldValues, resultTypes);
-    unsigned resultIndex = 0;
+        rewriter, warpOp, newYieldValues, newResultTypes);
+    // Replace results of the old warpOp by the new, deduplicated results.
     for (OpResult result : warpOp.getResults()) {
       if (result.use_empty())
         continue;
-      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
+      result.replaceAllUsesWith(
+          newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
     }
     rewriter.eraseOp(warpOp);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -669,3 +669,25 @@
   }
   return %r : vector<1xf32>
 }
+
+// -----
+
+// CHECK-PROP: func @dedup
+func.func @dedup(%laneid: index, %v0: vector<4xf32>, %v1: vector<4xf32>)
+    -> (vector<1xf32>, vector<1xf32>) {
+
+  // CHECK-PROP: %[[SINGLE_RES:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) {
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%v0, %v1 : vector<4xf32>, vector<4xf32>) -> (vector<1xf32>, vector<1xf32>) {
+    ^bb0(%arg0: vector<128xf32>, %arg1: vector<128xf32>):
+
+    // CHECK-PROP: %[[SINGLE_VAL:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>) -> vector<32xf32>
+    %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
+
+    // CHECK-PROP: vector.yield %[[SINGLE_VAL]] : vector<32xf32>
+    vector.yield %2, %2 : vector<32xf32>, vector<32xf32>
+  }
+
+  // CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}