diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -19,6 +19,9 @@
 /// Creates an instance of the `vector` dialect bufferization pass.
 std::unique_ptr<Pass> createVectorBufferizePass();
 
+/// Creates an instance of the `vector.mask` lowering pass.
+std::unique_ptr<Pass> createLowerVectorMaskPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
@@ -27,7 +30,6 @@
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
 } // namespace vector
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -16,4 +16,9 @@
   let constructor = "mlir::vector::createVectorBufferizePass()";
 }
 
+def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
+  let summary = "Lower 'vector.mask' operations";
+  let constructor = "mlir::vector::createLowerVectorMaskPass()";
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  LowerVectorMask.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -0,0 +1,147 @@
+//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.mask' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "lower-vector-mask"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_LOWERVECTORMASKPASS
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
+/// matching:
+///   1. It matches a `vector.mask` operation.
+///   2. It invokes `matchAndRewriteMaskableOp` on the `MaskableOpInterface`
+///      nested in the matched `vector.mask` operation.
+///
+/// It is required that the replacement op in the pattern replaces the
+/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
+/// approach allows having patterns that "stop" at every `vector.mask`
+/// operation and actually match the traits of the nested
+/// `MaskableOpInterface`.
+template <class SourceOp>
+struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
+  using OpRewritePattern<MaskOp>::OpRewritePattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(MaskOp maskOp,
+                  PatternRewriter &rewriter) const override final {
+    MaskableOpInterface maskableOp = maskOp.getMaskableOp();
+    SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
+    if (!sourceOp)
+      return failure();
+
+    return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+  }
+
+protected:
+  virtual LogicalResult
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
+
+/// Lowers a masked `vector.transfer_read` operation.
+struct MaskedTransferReadOpPattern
+    : public MaskOpRewritePattern<TransferReadOp> {
+public:
+  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
+    // expects a scalar. We could only lower one to the other for cases where
+    // the passthru is a broadcast of a scalar.
+    if (maskingOp.hasPassthru())
+      return rewriter.notifyMatchFailure(
+          maskingOp, "Can't lower passthru to vector.transfer_read");
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<TransferReadOp>(
+        maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
+        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
+        maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
+    return success();
+  }
+};
+
+/// Lowers a masked `vector.transfer_write` operation.
+struct MaskedTransferWriteOpPattern
+    : public MaskOpRewritePattern<TransferWriteOp> {
+public:
+  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
+                            MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    Type resultType =
+        writeOp.getResult() ? writeOp.getResult().getType() : Type();
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+        maskingOp.getOperation(), resultType, writeOp.getVector(),
+        writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
+        maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
+    return success();
+  }
+};
+
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
+    RewritePatternSet &patterns) {
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
+      patterns.getContext());
+}
+
+struct LowerVectorMaskPass
+    : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
+
+    if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
+                                            std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
+  return std::make_unique<LowerVectorMaskPass>();
+}
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -lower-vector-mask -split-input-file %s | FileCheck %s
+
+func.func @vector_transfer_read(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> vector<16xf32> {
+  %ft0 = arith.constant 0.0 : f32
+  %0 = vector.mask %m0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func.func @vector_transfer_read(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<?xf32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: index,
+// CHECK-SAME:    %[[VAL_2:.*]]: vector<16xi1>) -> vector<16xf32> {
+// CHECK-NOT:     vector.mask
+// CHECK:         %[[VAL_4:.*]] = vector.transfer_read {{.*}}, %[[VAL_2]] : tensor<?xf32>, vector<16xf32>
+// CHECK:         return %[[VAL_4]] : vector<16xf32>
+// CHECK:       }
+
+// -----
+
+func.func @vector_transfer_write_on_memref(%val: vector<16xf32>, %t0: memref<?xf32>, %idx: index, %m0: vector<16xi1>) {
+  vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_transfer_write_on_memref(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<16xf32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: memref<?xf32>,
+// CHECK-SAME:    %[[VAL_2:.*]]: index,
+// CHECK-SAME:    %[[VAL_3:.*]]: vector<16xi1>) {
+// CHECK-NOT:     vector.mask
+// CHECK:         vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, memref<?xf32>
+// CHECK:         return
+// CHECK:       }
+
+// -----
+
+func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  %res = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return %res : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vector_transfer_write_on_tensor(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<16xf32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<?xf32>,
+// CHECK-SAME:    %[[VAL_2:.*]]: index,
+// CHECK-SAME:    %[[VAL_3:.*]]: vector<16xi1>) -> tensor<?xf32> {
+// CHECK:         %[[VAL_4:.*]] = vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, tensor<?xf32>
+// CHECK:         return %[[VAL_4]] : tensor<?xf32>
+// CHECK:       }