diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -41,11 +41,17 @@
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
-/// Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Populates patterns for vectorizing convolution ops. Currently only 1-D
+/// convolutions in NWC x WCF -> NWF form with static, non-empty shapes are
+/// vectorized; other convolution ops are left untouched.
+void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Populate patterns that convert `ElementwiseMappable` ops to linalg
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1436,3 +1437,183 @@
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Convolution vectorization patterns
+//===----------------------------------------------------------------------===//
+namespace {
+/// Generate a vector implementation for:
+/// ```
+///   Op def: (     n,     w,     f,    kw,     c )
+///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+/// ```
+/// w and kw are unrolled.
+/// TODO: only unroll w (resp. kw) when strideW (resp. dilationW) is > 1.
+struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
+  Conv1D_NWC_WCF_Generator(PatternRewriter &rewriter, LinalgOp linalgOp,
+                           int strideW, int dilationW)
+      : StructuredGenerator<LinalgOp>(rewriter, linalgOp), valid(false),
+        strideW(strideW), dilationW(dilationW) {
+    // Determine whether `linalgOp` can be generated with this generator.
+    if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+      return;
+    lhsShaped = linalgOp.inputs()[0];
+    rhsShaped = linalgOp.inputs()[1];
+    resShaped = linalgOp.outputs()[0];
+    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
+    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
+    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    if (!lhsShapedType || !rhsShapedType || !resShapedType)
+      return;
+    if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+        resShapedType.getRank() != 3)
+      return;
+    // The sizes read by `conv()` become vector dimensions or unroll trip
+    // counts; vector dimensions must be static and strictly positive. Reject
+    // dynamic or empty operands up front rather than tripping the VectorType
+    // verifier (or dereferencing a null `write`) in `conv()`.
+    auto isStaticAndNonEmpty = [](ShapedType t) {
+      return t.hasStaticShape() && t.getNumElements() > 0;
+    };
+    if (!isStaticAndNonEmpty(lhsShapedType) ||
+        !isStaticAndNonEmpty(rhsShapedType) ||
+        !isStaticAndNonEmpty(resShapedType))
+      return;
+    // TODO: `conv()` always emits a default (mul, add) vector.contract and the
+    // payload region is never inspected; check it before accepting the op.
+    // The op is now known to be valid.
+    valid = true;
+  }
+
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,     w,     f,    kw,     c )
+  ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  /// ```
+  /// w and kw are unrolled.
+  /// TODO: only unroll w (resp. kw) when strideW (resp. dilationW) is > 1.
+  LogicalResult conv() {
+    if (!valid)
+      return failure();
+
+    int64_t nSize = lhsShapedType.getShape()[0];
+    int64_t wSize = resShapedType.getShape()[1];
+    int64_t cSize = lhsShapedType.getShape()[2];
+    int64_t kwSize = rhsShapedType.getShape()[0];
+    int64_t fSize = rhsShapedType.getShape()[2];
+
+    vector::TransferWriteOp write;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    // Unroll along kw and read slices of lhs and rhs.
+    // Alternatively we could preload both 3-d slices and extract smaller slices
+    // iteratively without touching memory. But this will quickly spill.
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      // Read rhs slice of size {1, c, f} @ [kw, 0, 0].
+      Value kwVal = rewriter.create<arith::ConstantIndexOp>(loc, kw);
+      VectorType rhsType =
+          VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType());
+      Value rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
+
+      for (int64_t w = 0; w < wSize; ++w) {
+        // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0].
+        Value lhsStridedIdx = rewriter.create<arith::ConstantIndexOp>(
+            loc, strideW * w + dilationW * kw);
+        VectorType lhsType =
+            VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType());
+        Value lhs = rewriter.create<vector::TransferReadOp>(
+            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
+
+        // Read res slice: {n, 1, f} @ [0, w, 0].
+        Value wVal = rewriter.create<arith::ConstantIndexOp>(loc, w);
+        VectorType resType =
+            VectorType::get({nSize, 1, fSize}, resShapedType.getElementType());
+        Value res = rewriter.create<vector::TransferReadOp>(
+            loc, resType, resShaped, ValueRange{zero, wVal, zero});
+
+        // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f}
+        StringRef par = Par().strRef, red = Red().strRef;
+        AffineExpr n, one, f, c;
+        bindDims(ctx, n, one, f, c);
+        // clang-format off
+        res = rewriter.create<vector::ContractionOp>(
+          loc, lhs, rhs, res,
+          /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}},
+          /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+        // clang-format on
+
+        // Write back res slice: {n, 1, f} @ [0, w, 0].
+        // When the destination is a tensor, the write returns the updated
+        // tensor, which the following reads and writes must use.
+        write = rewriter.create<vector::TransferWriteOp>(
+            loc, res, resShaped, ValueRange{zero, wVal, zero});
+        if (write.getNumResults() == 1)
+          resShaped = write->getResult(0);
+      }
+    }
+
+    // The constructor guarantees kwSize >= 1 and wSize >= 1, so at least one
+    // transfer_write was created.
+    assert(write && "expected at least one vector.transfer_write");
+    if (write.getNumResults() > 0)
+      rewriter.replaceOp(op, write->getResult(0));
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+
+  /// Entry point that transposes into the common form:
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  LogicalResult generateConv() {
+    // Dims are bound in the op's iteration-domain order (d0 .. d4).
+    AffineExpr n, w, f, kw, c;
+    bindDims(ctx, n, w, f, kw, c);
+
+    if (!iters({Par(), Par(), Par(), Red(), Red()}))
+      return failure();
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw, c, f},
+                /*resIndex*/ {n, w, f}}))
+      return conv();
+    return failure();
+  }
+
+private:
+  bool valid;
+  int strideW, dilationW;
+  Value lhsShaped, rhsShaped, resShaped;
+  ShapedType lhsShapedType, rhsShapedType, resShapedType;
+};
+
+/// Rewrites ConvolutionOpInterface ops into vector form. Only the shapes and
+/// layouts accepted by `Conv1D_NWC_WCF_Generator` are rewritten; all other
+/// ops fail to match and are left untouched.
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+    // TODO: these are legitimately part of ConvolutionOpInterface.
+    // An absent attribute is treated as a stride (resp. dilation) of 1.
+    auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
+    auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+    auto stride = strides ? *strides.getValues<int64_t>().begin() : 1;
+    auto dilation = dilations ? *dilations.getValues<int64_t>().begin() : 1;
+    Conv1D_NWC_WCF_Generator e(rewriter, linalgOp, stride, dilation);
+    return e.generateConv();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateConvolutionVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s
+
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// w == 0, kw == 0
+// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 0
+// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+
+// -----
+
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// w == 0, kw == 0 (input offset 3 * w + 2 * kw = 0)
+// CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT0_A:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 0 (input offset 3 * w + 2 * kw = 3)
+// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT1_A:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+
+/// w == 0, kw == 1 (input offset 3 * w + 2 * kw = 2)
+// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT0_B:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 1 (input offset 3 * w + 2 * kw = 5)
+// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT1_B:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
+// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -553,6 +553,7 @@
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
   populatePadTensorOpVectorizationPatterns(patterns);
+  populateConvolutionVectorizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 