diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -23,6 +23,17 @@
 void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns);
 
+/// Appends patterns to convert vector reduction of the form:
+/// ```
+/// vector.reduction <add>, (muli (ext %lhs), (ext %rhs)), [%acc]
+/// ```
+///
+/// to SPIR-V integer dot product ops. Only non-scalable 1-D vectors with
+/// four elements are matched, where each `ext` is `arith.extsi` or
+/// `arith.extui` from i8 elements and the reduction result is i32 or i64.
+void populateVectorReductionToSPIRVDotProductPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -10,6 +10,7 @@
   intrinsics_gen
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRSPIRVDialect
   MLIRSPIRVConversion
   MLIRVectorDialect
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -20,6 +21,9 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -436,6 +440,100 @@
   }
 };
 
+/// Rewrites an integer `vector.reduction <add>` whose operand is an
+/// `arith.muli` of two extended i8 vectors into a SPIR-V integer dot product.
+///
+/// The signedness of the two `arith.ext{s,u}i` producers selects between
+/// `spirv.SDot`, `spirv.UDot` and `spirv.SUDot`; if the reduction has an
+/// accumulator, the matching `*AccSat` op is emitted instead.
+///
+/// NOTE(review): `*DotAccSat` adds the accumulator with saturation, whereas
+/// `vector.reduction <add>` with an accumulator wraps on overflow; the two
+/// only agree when `acc + dot` does not overflow the result type.
+struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
+
+    auto resultType = dyn_cast<IntegerType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "result is not an integer");
+
+    int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
+    if (!llvm::is_contained({32, 64}, resultBitwidth))
+      return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
+
+    VectorType inVecTy = op.getSourceVectorType();
+    if (inVecTy.getNumElements() != 4 || inVecTy.getShape().size() != 1 ||
+        inVecTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "unsupported vector shape");
+
+    auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(
+          op, "reduction operand is not 'arith.muli'");
+
+    // Try each supported (LHS ext, RHS ext) signedness combination. At most
+    // one can match, as each `muli` operand has a single defining op.
+    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
+                             spirv::SDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
+                             spirv::UDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
+      return success();
+
+    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
+                             spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
+      return success();
+
+    return failure();
+  }
+
+private:
+  /// Matches `mul` as `muli(LhsExtensionOp(%a), RhsExtensionOp(%b))` where
+  /// `%a` and `%b` have i8 elements, then replaces `op` with `OpTy` (no
+  /// accumulator) or `AccOpTy` (accumulator present). `SwapOperands` swaps
+  /// `%a` and `%b` before the SPIR-V op is built.
+  template <typename LhsExtensionOp, typename RhsExtensionOp, typename OpTy,
+            typename AccOpTy, bool SwapOperands>
+  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
+                                  PatternRewriter &rewriter) {
+    auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
+    if (!lhs || !getElementTypeOrSelf(lhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
+    if (!rhs || !getElementTypeOrSelf(rhs.getIn().getType()).isInteger(8))
+      return failure();
+
+    Value lhsIn = lhs.getIn();
+    Value rhsIn = rhs.getIn();
+
+    // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
+    // we have to swap operands instead in that case.
+    if (SwapOperands)
+      std::swap(lhsIn, rhsIn);
+
+    if (Value acc = op.getAcc()) {
+      rewriter.replaceOpWithNewOp<AccOpTy>(op, op.getType(), lhsIn, rhsIn, acc,
+                                           nullptr);
+    } else {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op.getType(), lhsIn, rhsIn,
+                                        nullptr);
+    }
+
+    return success();
+  }
+};
+
 } // namespace
 #define CL_MAX_MIN_OPS                                                         \
   spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
@@ -457,3 +555,8 @@
                VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
                                                            patterns.getContext());
 }
+
+void mlir::populateVectorReductionToSPIRVDotProductPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorReductionToDotProd>(patterns.getContext());
+}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics \
+// RUN:   --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s
+
+// Positive tests.
+
+// CHECK-LABEL: func.func @to_sdot
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_sdot_acc
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_sdot_i64
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
+// CHECK-NEXT: return [[DOT]] : i64
+func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+  %mul = arith.muli %lhs, %rhs : vector<4xi64>
+  %red = vector.reduction <add>, %mul : vector<4xi64> into i64
+  return %red : i64
+}
+
+// CHECK-LABEL: func.func @to_sdot_acc_i64
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
+// CHECK-NEXT: return [[DOT]] : i64
+func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
+  %mul = arith.muli %lhs, %rhs : vector<4xi64>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi64> into i64
+  return %red : i64
+}
+
+// CHECK-LABEL: func.func @to_udot
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_udot_acc
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_signed_unsigned_dot
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_unsigned_signed_dot
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
+// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
+// CHECK-NEXT: return [[DOT]] : i32
+func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
+  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
+  return %red : i32
+}
+
+// -----
+// Negative tests.
+
+// CHECK-LABEL: func.func @too_short
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
+// CHECK: [[RED:%.+]] = vector.reduction <add>
+// CHECK-NEXT: return [[RED]] : i32
+func.func @too_short(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
+  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
+  %mul = arith.muli %lhs, %rhs : vector<3xi32>
+  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @too_long
+// CHECK-SAME: ([[ARG0:%.+]]: vector<6xi8>, [[ARG1:%.+]]: vector<6xi8>)
+// CHECK: [[RED:%.+]] = vector.reduction <add>
+// CHECK-NEXT: return [[RED]] : i32
+func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
+  %rhs = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
+  %mul = arith.muli %lhs, %rhs : vector<6xi32>
+  %red = vector.reduction <add>, %mul : vector<6xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_reduction_kind
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK: [[RED:%.+]] = vector.reduction <mul>
+// CHECK-NEXT: return [[RED]] : i32
+func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %mul = arith.muli %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <mul>, %mul : vector<4xi32> into i32
+  return %red : i32
+}
+
+// CHECK-LABEL: func.func @wrong_arith_op
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
+// CHECK: [[ADD:%.+]] = arith.addi
+// CHECK: [[RED:%.+]] = vector.reduction <add>, [[ADD]]
+// CHECK-NEXT: return [[RED]] : i32
+func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
+  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
+  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
+  %add = arith.addi %lhs, %rhs : vector<4xi32>
+  %red = vector.reduction <add>, %add : vector<4xi32> into i32
+  return %red : i32
+}
diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(FuncToLLVM)
+add_subdirectory(VectorToSPIRV)
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestVectorToSPIRV
+  TestVectorReductionToSPIRVDotProd.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRVectorToSPIRV
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRSPIRVDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp
@@ -0,0 +1,60 @@
+//===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+/// Test pass that greedily applies the vector.reduction -> SPIR-V integer dot
+/// product patterns to each `func.func` it is scheduled on.
+struct TestVectorReductionToSPIRVDotProd
+    : PassWrapper<TestVectorReductionToSPIRVDotProd,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorReductionToSPIRVDotProd)
+
+  StringRef getArgument() const final {
+    return "test-vector-reduction-to-spirv-dot-prod";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns that converts vector.reduction to SPIR-V "
+           "integer dot product ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The patterns create SPIR-V ops, so that dialect must be loaded up front.
+    registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorReductionToSPIRVDotProductPatterns(patterns);
+    // The LogicalResult is discarded, so this pass never signals failure.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace test {
+/// Registers the test pass so mlir-opt can run it by its argument name.
+void registerTestVectorReductionToSPIRVDotProd() {
+  PassRegistration<TestVectorReductionToSPIRVDotProd>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -41,6 +41,7 @@
     MLIRTestTransforms
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
+    MLIRTestVectorToSPIRV
     MLIRLLVMTestPasses
     )
 endif()
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -124,6 +124,7 @@
 void registerTestTransformDialectInterpreterPass();
 void registerTestWrittenToPass();
 void registerTestVectorLowerings();
+void registerTestVectorReductionToSPIRVDotProd();
 void registerTestNvgpuLowerings();
 } // namespace test
 } // namespace mlir
@@ -231,6 +232,7 @@
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestNvgpuLowerings();
   mlir::test::registerTestWrittenToPass();
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4263,9 +4263,11 @@
     ]),
     includes = ["include"],
     deps = [
+        ":ArithDialect",
         ":ConversionPassIncGen",
         ":IR",
         ":Pass",
+        ":Support",
         ":SPIRVConversion",
         ":SPIRVDialect",
         ":Transforms",
@@ -7077,6 +7079,7 @@
         "//mlir/test:TestTransforms",
         "//mlir/test:TestTypeDialect",
         "//mlir/test:TestVector",
+        "//mlir/test:TestVectorToSPIRV",
     ],
 )
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -488,7 +488,6 @@
         "//mlir:SCFDialect",
         "//mlir:SPIRVDialect",
         "//mlir:Support",
-        "//mlir:TransformUtils",
     ],
 )
 
@@ -507,6 +506,20 @@
     ],
 )
 
+cc_library(
+    name = "TestVectorToSPIRV",
+    srcs = glob(["lib/Conversion/VectorToSPIRV/*.cpp"]),
+    deps = [
+        "//mlir:ArithDialect",
+        "//mlir:FuncDialect",
+        "//mlir:Pass",
+        "//mlir:SPIRVDialect",
+        "//mlir:Transforms",
+        "//mlir:VectorDialect",
+        "//mlir:VectorToSPIRV",
+    ],
+)
+
 cc_library(
     name = "TestAffine",
     srcs = glob([