diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -457,8 +457,8 @@
       return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
 
     VectorType inVecTy = op.getSourceVectorType();
-    if (inVecTy.getNumElements() != 4 || inVecTy.getShape().size() != 1 ||
-        inVecTy.isScalable())
+    if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
+        inVecTy.getShape().size() != 1 || inVecTy.isScalable())
       return rewriter.notifyMatchFailure(op, "unsupported vector shape");
 
     auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
@@ -491,15 +491,33 @@
   static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
                                   PatternRewriter &rewriter) {
     auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
-    if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+    if (!lhs)
+      return failure();
+    Value lhsIn = lhs.getIn();
+    auto lhsInType = cast<VectorType>(lhsIn.getType());
+    if (!lhsInType.getElementType().isInteger(8))
       return failure();
 
     auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
-    if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+    if (!rhs)
       return failure();
-
-    Value lhsIn = lhs.getIn();
     Value rhsIn = rhs.getIn();
+    auto rhsInType = cast<VectorType>(rhsIn.getType());
+    if (!rhsInType.getElementType().isInteger(8))
+      return failure();
+
+    // Pad 3-element operands to vector<4xi8> with a zero lane; the extra
+    // 0 * 0 term leaves the dot product unchanged.
+    if (op.getSourceVectorType().getNumElements() == 3) {
+      IntegerType i8Type = rewriter.getI8Type();
+      auto v4i8Type = VectorType::get({4}, i8Type);
+      Location loc = op.getLoc();
+      Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
+      lhsIn = rewriter.create<spirv::CompositeConstructOp>(
+          loc, v4i8Type, ValueRange{lhsIn, zero});
+      rhsIn = rewriter.create<spirv::CompositeConstructOp>(
+          loc, v4i8Type, ValueRange{rhsIn, zero});
+    }
 
     // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
     // we have to swap operands instead in that case.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -123,18 +123,33 @@
   return %red : i32
 }
 
+// CHECK-LABEL: func.func @to_sdot_vector3
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<3xi8>)
+// CHECK:         %[[ZERO:.+]] = spirv.Constant 0 : i8
+// CHECK:         %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+// CHECK:         %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
+// CHECK:         %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK:         return %[[SDOT]]
+func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+  %mul = arith.muli %lhs, %rhs : vector<3xi32>
+  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+  return %red : i32
+}
+
 // -----
 
 // Negative tests.
 
 // CHECK-LABEL: func.func @too_short
-// CHECK-SAME:    ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi8>, [[ARG1:%.+]]: vector<2xi8>)
 // CHECK:         [[RED:%.+]] = vector.reduction
 // CHECK-NEXT:    return [[RED]] : i32
-func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
-  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
-  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
-  %mul = arith.muli %lhs, %rhs : vector<3xi32>
-  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
+  %rhs = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
+  %mul = arith.muli %lhs, %rhs : vector<2xi32>
+  %red = vector.reduction <add>, %mul : vector<2xi32> into i32
   return %red : i32
 }