diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -995,19 +995,20 @@ unsigned int operandNumber = operand->getOperandNumber(); auto extractOp = operand->get().getDefiningOp(); VectorType extractSrcType = extractOp.getVectorType(); - bool is0dExtract = extractSrcType.getRank() == 0; + bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; Type elType = extractSrcType.getElementType(); VectorType distributedVecType; - if (!is0dExtract) { + if (!is0dOrVec1Extract) { assert(extractSrcType.getRank() == 1 && "expected that extractelement src rank is 0 or 1"); + if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) + return failure(); int64_t elementsPerLane = extractSrcType.getShape()[0] / warpOp.getWarpSize(); distributedVecType = VectorType::get({elementsPerLane}, elType); } else { distributedVecType = extractSrcType; } - // Yield source vector from warp op. Location loc = extractOp.getLoc(); SmallVector newRetIndices; @@ -1019,9 +1020,17 @@ // 0d extract: The new warp op broadcasts the source vector to all lanes. // All lanes extract the scalar. - if (is0dExtract) { - Value newExtract = - rewriter.create(loc, distributedVec); + if (is0dOrVec1Extract) { + Value newExtract; + if (extractSrcType.getRank() == 1) { + newExtract = rewriter.create( + loc, distributedVec, + rewriter.create(loc, 0)); + + } else { + newExtract = + rewriter.create(loc, distributedVec); + } newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -761,6 +761,26 @@ // ----- +// CHECK-PROP-LABEL: func.func @vector_extractelement_1element( +// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index +// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { +// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32> +// CHECK-PROP: vector.yield %[[V]] : vector<1xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32> +// CHECK-PROP: return %[[E]] : f32 +func.func @vector_extractelement_1element(%laneid: index) -> (f32) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { + %0 = "some_def"() : () -> (vector<1xf32>) + %c0 = arith.constant 0 : index + %1 = vector.extractelement %0[%c0 : index] : vector<1xf32> + vector.yield %1 : f32 + } + return %r : f32 +} + +// ----- + // CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)> // CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)> // CHECK-PROP-LABEL: func.func @vector_extractelement_1d(