Index: mlir/include/mlir/Dialect/Vector/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/Vector/CMakeLists.txt +++ mlir/include/mlir/Dialect/Vector/CMakeLists.txt @@ -6,3 +6,9 @@ mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRVectorOpsEnumsIncGen) add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector) +add_public_tablegen_target(MLIRVectorPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc VectorPasses ./) Index: mlir/include/mlir/Dialect/Vector/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Vector/Passes.h @@ -0,0 +1,33 @@ +//===- Passes.h - Vector pass entry points ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares the pass constructors and pass registration +// hooks of the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_PASSES_H_ +#define MLIR_DIALECT_VECTOR_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +/// Creates a pass that lowers vector transfer ops to other vector ops. +std::unique_ptr createVectorTransferLoweringPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Vector/Passes.h.inc" + +} // end namespace mlir + +#endif // MLIR_DIALECT_VECTOR_PASSES_H_ Index: mlir/include/mlir/Dialect/Vector/Passes.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Vector/Passes.td @@ -0,0 +1,44 @@ +//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_PASSES +#define MLIR_DIALECT_VECTOR_PASSES + +include "mlir/Pass/PassBase.td" + +def VectorTransferLowering : FunctionPass<"vector-transfer-lowering"> { + let summary = "Lower transfer ops to other vector ops"; + let description = [{ + This pass lowers `vector.transfer_read` and `vector.transfer_write` ops to + simpler vector ops like `vector.load`, `vector.store` and + `vector.broadcast`. Note that certain features are not supported, currently + these are: + - Non-default layouts of memrefs. + - Masking. + - Vector element types (unless it coincides with the result type). + - Permutation of dimensions. + - Broadcasting is supported only for `vector.transfer_read`. + More cases will be supported in the future. 
+ + Example: + + ```mlir + %r = vector.transfer_read %m[%i, %j], %c0 {masked = [false]} : memref<8x8xf32>, vector<4xf32> + ``` + + Output: + + ```mlir + %r = vector.load %m[%i, %j] : memref<8x8xf32>, vector<4xf32> + ``` + }]; + let constructor = "mlir::createVectorTransferLoweringPass()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + +#endif // MLIR_DIALECT_VECTOR_PASSES Index: mlir/include/mlir/InitAllPasses.h =================================================================== --- mlir/include/mlir/InitAllPasses.h +++ mlir/include/mlir/InitAllPasses.h @@ -27,6 +27,7 @@ #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Passes.h" #include "mlir/Transforms/Passes.h" #include @@ -61,6 +62,7 @@ registerStandardPasses(); tensor::registerTensorPasses(); tosa::registerTosaOptPasses(); + registerVectorPasses(); } } // namespace mlir Index: mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -44,5 +44,6 @@ MLIRTransforms MLIRTransformUtils MLIRVector + MLIRVectorTransforms MLIRVectorToSCF ) Index: mlir/lib/Dialect/Vector/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Vector/CMakeLists.txt +++ mlir/lib/Dialect/Vector/CMakeLists.txt @@ -1,8 +1,6 @@ add_mlir_dialect_library(MLIRVector - VectorOps.cpp - VectorTransferOpTransforms.cpp - VectorTransforms.cpp - VectorUtils.cpp + IR/VectorOps.cpp + Utils/VectorUtils.cpp EDSC/Builders.cpp ADDITIONAL_HEADER_DIRS @@ -26,3 +24,5 @@ MLIRSideEffectInterfaces MLIRVectorInterfaces ) + +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ 
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIRVectorTransforms + VectorTransferLowering.cpp + VectorTransferOpTransforms.cpp + VectorTransforms.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector + + DEPENDS + MLIRVectorPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMemRef + MLIRPass + MLIRSCF + MLIRVector + MLIRTransforms + MLIRTransformUtils + ) Index: mlir/lib/Dialect/Vector/Transforms/PassDetail.h =================================================================== --- /dev/null +++ mlir/lib/Dialect/Vector/Transforms/PassDetail.h @@ -0,0 +1,25 @@ +//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_MLIR_LIB_DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H +#define LLVM_MLIR_LIB_DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace memref { +class MemRefDialect; +} // end namespace memref + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/Vector/Passes.h.inc" + +} // end namespace mlir + +#endif // LLVM_MLIR_LIB_DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H Index: mlir/lib/Dialect/Vector/Transforms/VectorTransferLowering.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Vector/Transforms/VectorTransferLowering.cpp @@ -0,0 +1,39 @@ +//===- VectorTransferLowering.cpp - Progressive Lowering of transfer ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the vector transfer lowering pass, which greedily +// applies the transfer-lowering patterns defined in `VectorTransforms.cpp`. +// +//===----------------------------------------------------------------------===// + +#include + +#include "PassDetail.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Passes.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +struct VectorTransferLoweringPass + : public VectorTransferLoweringBase { + void runOnFunction() override { + OwningRewritePatternList patterns; + mlir::vector::populateVectorTransferLoweringPatterns(patterns, + &getContext()); + (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + } +}; +} // namespace + +std::unique_ptr mlir::createVectorTransferLoweringPass() { + return std::make_unique(); +} Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2784,7 +2784,7 @@ // If broadcasting is required and the number of loaded elements is 1 then // we can create `memref.load` instead of `vector.load`. loadOp = rewriter.create(read.getLoc(), read.source(), - read.indices()); + read.indices()); } else { // Otherwise create `vector.load`. 
loadOp = rewriter.create(read.getLoc(), Index: mlir/test/Dialect/Vector/vector-transfer-lowering.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -vector-transfer-lowering -split-input-file | FileCheck %s // transfer_read/write are lowered to vector.load/store // CHECK-LABEL: func @transfer_to_load( Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -363,18 +363,6 @@ void runOnFunction() override { transferOpflowOpt(getFunction()); } }; -struct TestVectorTransferLoweringPatterns - : public PassWrapper { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnFunction() override { - OwningRewritePatternList patterns; - populateVectorTransferLoweringPatterns(patterns, &getContext()); - (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); - } -}; - } // end anonymous namespace namespace mlir { @@ -417,10 +405,6 @@ PassRegistration transferOpOpt( "test-vector-transferop-opt", "Test optimization transformations for transfer ops"); - - PassRegistration transferOpLoweringPass( - "test-vector-transfer-lowering-patterns", - "Test conversion patterns to lower transfer ops to other vector ops"); } } // namespace test } // namespace mlir