diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -897,16 +897,81 @@ return failure(); unsigned int operandNumber = operand->getOperandNumber(); auto extractOp = operand->get().getDefiningOp(); - if (extractOp.getVectorType().getNumElements() != 1) - return failure(); + VectorType extractSrcType = extractOp.getVectorType(); Location loc = extractOp.getLoc(); + + // "vector.extract %v[] : vector<f32>" is an invalid op. + assert(extractSrcType.getRank() > 0 && + "vector.extract does not support rank 0 sources"); + + // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v. + if (extractOp.getPosition().empty()) + return failure(); + + // Rewrite vector.extract with 1d source to vector.extractelement. + if (extractSrcType.getRank() == 1) { + assert(extractOp.getPosition().size() == 1 && "expected 1 index"); + int64_t pos = extractOp.getPosition()[0].cast().getInt(); + rewriter.setInsertionPoint(extractOp); + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getVector(), + rewriter.create(loc, pos)); + return success(); + } + + // All following cases are 2d or higher dimensional source vectors. + + if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { + // There is no distribution, this is a broadcast. Simply move the extract + // out of the warp op. + // TODO: This could be optimized. E.g., in case of a scalar result, let + // one lane extract and shuffle the result to all other lanes (same as + // the 1d case). 
+ SmallVector newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {extractOp.getVector()}, + {extractOp.getVectorType()}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value distributedVec = newWarpOp->getResult(newRetIndices[0]); + // Extract from distributed vector. + Value newExtract = rewriter.create( + loc, distributedVec, extractOp.getPosition()); + newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + return success(); + } + + // Find the distributed dimension. There should be exactly one. + auto distributedType = + warpOp.getResult(operandNumber).getType().cast(); + auto yieldedType = operand->get().getType().cast(); + int64_t distributedDim = -1; + for (int64_t i = 0; i < yieldedType.getRank(); ++i) { + if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { + // Keep this assert here in case WarpExecuteOnLane0Op gets extended to + // support distributing multiple dimensions in the future. + assert(distributedDim == -1 && "found multiple distributed dims"); + distributedDim = i; + } + } + assert(distributedDim != -1 && "could not find distributed dimension"); + + // Yield source vector from warp op. + SmallVector newDistributedShape(extractSrcType.getShape().begin(), + extractSrcType.getShape().end()); + for (int i = 0; i < distributedType.getRank(); ++i) + newDistributedShape[i + extractOp.getPosition().size()] = + distributedType.getDimSize(i); + auto newDistributedType = + VectorType::get(newDistributedShape, distributedType.getElementType()); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()}, + rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); + Value distributedVec = newWarpOp->getResult(newRetIndices[0]); + // Extract from distributed vector. 
Value newExtract = rewriter.create( - loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition()); + loc, distributedVec, extractOp.getPosition()); newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -648,17 +648,58 @@ // ----- -// CHECK-PROP-LABEL: func.func @vector_extract_simple( -// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { -// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32> -// CHECK-PROP: vector.yield %[[V]] : vector<1xf32> +// TODO: We could use warp shuffles instead of broadcasting the entire vector. + +// CHECK-PROP-LABEL: func.func @vector_extract_1d( +// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32 +// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK-PROP: vector.yield %[[V]] : vector<64xf32> // CHECK-PROP: } -// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : vector<1xf32> -// CHECK-PROP: return %[[E]] : f32 -func.func @vector_extract_simple(%laneid: index) -> (f32) { +// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32> +// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]] +// CHECK-PROP: return %[[SHUFFLED]] : f32 +func.func @vector_extract_1d(%laneid: index) -> (f32) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { - %0 = "some_def"() : () -> (vector<1xf32>) - %1 = vector.extract %0[0] : vector<1xf32> + %0 = "some_def"() : () -> (vector<64xf32>) + %1 = vector.extract %0[9] : vector<64xf32> + vector.yield %1 : f32 + } + return %r : f32 +} + +// ----- + +// 
CHECK-PROP-LABEL: func.func @vector_extract_2d( +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x3xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<5x3xf32> +// CHECK-PROP: return %[[E]] +func.func @vector_extract_2d(%laneid: index) -> (vector<3xf32>) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<5x96xf32>) + %1 = vector.extract %0[2] : vector<5x96xf32> + vector.yield %1 : vector<96xf32> + } + return %r : vector<3xf32> +} + +// ----- + +// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast_scalar( +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][1, 2] : vector<5x96xf32> +// CHECK-PROP: return %[[E]] +func.func @vector_extract_2d_broadcast_scalar(%laneid: index) -> (f32) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { + %0 = "some_def"() : () -> (vector<5x96xf32>) + %1 = vector.extract %0[1, 2] : vector<5x96xf32> vector.yield %1 : f32 } return %r : f32 @@ -666,6 +707,42 @@ // ----- +// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast( +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<5x96xf32> +// CHECK-PROP: return %[[E]] +func.func @vector_extract_2d_broadcast(%laneid: index) -> (vector<96xf32>) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) { + %0 = "some_def"() : () -> (vector<5x96xf32>) + %1 = vector.extract %0[2] : vector<5x96xf32> + vector.yield %1 : 
vector<96xf32> + } + return %r : vector<96xf32> +} + +// ----- + +// CHECK-PROP-LABEL: func.func @vector_extract_3d( +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8x4x96xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[V]] : vector<8x128x96xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<8x4x96xf32> +// CHECK-PROP: return %[[E]] +func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) { + %0 = "some_def"() : () -> (vector<8x128x96xf32>) + %1 = vector.extract %0[2] : vector<8x128x96xf32> + vector.yield %1 : vector<128x96xf32> + } + return %r : vector<4x96xf32> +} + +// ----- + // CHECK-PROP-LABEL: func.func @vector_extractelement_0d( // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector