diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -41,11 +41,15 @@ //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector; -/// Populates patterns for vectorization of all ConvN-D ops. +/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, ArrayRef tileSizes); +/// Populates patterns for vectorizing convolution ops (currently 1-D NWC/WCF). +void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate patterns that convert `ElementwiseMappable` ops to linalg /// parallel loops. void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -25,7 +25,7 @@ namespace mlir { -class PatternRewriter; +class OpBuilder; /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant.
Note that this only checks the affine maps from an @@ -161,8 +161,8 @@ Win() : IteratorType(getWindowIteratorTypeName()) {} }; - StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op) - : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()), + StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) + : builder(builder), ctx(op.getContext()), loc(op.getLoc()), iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {} bool iters(ArrayRef its) { @@ -181,7 +181,7 @@ } protected: - PatternRewriter &rewriter; + OpBuilder &builder; MLIRContext *ctx; Location loc; ArrayAttr iterators; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -44,6 +45,12 @@ #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) +// Forward declarations. +static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp, + SmallVectorImpl &newResults); +static FailureOr +vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp); + /// Return the unique instance of OpType in `block` if it is indeed unique. /// Return null if none or more than 1 instances exist. template @@ -147,7 +154,7 @@ auto linalgOp = cast(outputOperand->getOwner()); unsigned outputPos = outputOperand->getOperandNumber() - linalgOp.getNumInputs(); - // Only single combiner operatios are supported for now. + // Only single combiner operations are supported for now.
SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || combinerOps.size() != 1) @@ -575,6 +582,11 @@ return success(); } +/// Helper function to vectorize a `linalgOp` with contraction semantics in a +/// generic fashion. +/// This helper is needed atm because the truly generic implementation requires +/// good vector.multi_reduce folding patterns that are currently NYI. +// TODO: drop reliance on a specific pattern. static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl &newResults) { assert(isaContractionOpInterface(linalgOp) && @@ -664,6 +676,11 @@ return success(); if (isaContractionOpInterface(linalgOp)) return success(); + // TODO: isaConvolutionOpInterface that can also infer from generic features. + // But we will still need stride/dilation attributes that will be annoying to + // reverse-engineer... + if (isa(op)) + return success(); // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. @@ -688,6 +705,18 @@ if (isaContractionOpInterface(linalgOp)) return vectorizeContraction(b, linalgOp, newResults); + // TODO: isaConvolutionOpInterface that can also infer from generic features. + // But we will still need stride/dilation attributes that will be annoying to + // reverse-engineer...
+ if (auto convOp = dyn_cast(op)) { + FailureOr resultOrFail = vectorizeConvolution(b, convOp); + if (failed(resultOrFail)) + return failure(); + Operation *newOp = *resultOrFail; + llvm::append_range(newResults, newOp->getResults()); + return success(); + } + LDBG("" << "Vectorize linalg op as a generic by broadcasting to " "maximal common shape: " @@ -1421,3 +1450,195 @@ return success(); } + +//===----------------------------------------------------------------------===// +// Convolution vectorization patterns +//===----------------------------------------------------------------------===// +namespace { +/// Generate a vector implementation for: +/// ``` +/// Op def: ( n, w, f, kw, c ) +/// Iters: ({Par(), Par(), Par(), Red(), Red()}) +/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} +/// ``` +/// w and kw are unrolled. +/// TODO: only unroll w (resp. kw) when strideW (resp. dilationW) is > 1. +struct Conv1D_NWC_WCF_Generator : public StructuredGenerator { + Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW, + int dilationW) + : StructuredGenerator(builder, linalgOp), valid(false), + strideW(strideW), dilationW(dilationW) { + // Determine whether `linalgOp` can be generated with this generator + if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) + return; + lhsShaped = linalgOp.inputs()[0]; + rhsShaped = linalgOp.inputs()[1]; + resShaped = linalgOp.outputs()[0]; + lhsShapedType = lhsShaped.getType().dyn_cast(); + rhsShapedType = rhsShaped.getType().dyn_cast(); + resShapedType = resShaped.getType().dyn_cast(); + if (!lhsShapedType || !rhsShapedType || !resShapedType) + return; + if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 || + resShapedType.getRank() != 3) + return; + + // Check for reduction `add` preceded by `mul`.
+ Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0)); + if (!reduceOp) + return; + llvm::Optional maybeKind; + maybeKind = getKindForOp(reduceOp); + if (!maybeKind || *maybeKind != vector::CombiningKind::ADD) + return; + maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front())); + if (!maybeKind || *maybeKind != vector::CombiningKind::MUL) + return; + + // The op is now known to be valid. + valid = true; + } + + /// Generate a vector implementation for: + /// ``` + /// Op def: ( n, w, f, kw, c ) + /// Iters: ({Par(), Par(), Par(), Red(), Red()}) + /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} + /// ``` + /// w and kw are unrolled. + /// TODO: only unroll w (resp. kw) when strideW (resp. dilationW) is > 1. + FailureOr conv() { + if (!valid) + return failure(); + + int nSize = lhsShapedType.getShape()[0]; + int wSize = resShapedType.getShape()[1]; + int cSize = lhsShapedType.getShape()[2]; + int kwSize = rhsShapedType.getShape()[0]; + int fSize = rhsShapedType.getShape()[2]; + + // Vector types need static, strictly positive sizes; bail out before + // creating any IR otherwise (this also guarantees `write` is set below). + if (!lhsShapedType.hasStaticShape() || !rhsShapedType.hasStaticShape() || + !resShapedType.hasStaticShape() || nSize <= 0 || wSize <= 0 || + cSize <= 0 || kwSize <= 0 || fSize <= 0) + return failure(); + + vector::TransferWriteOp write; + Value zero = builder.create(loc, 0); + + // Unroll along kw and read slices of lhs and rhs. + // Alternatively we could preload both 3-d slices and extract smaller slices + // iteratively without touching memory. But this will quickly spill. + for (int64_t kw = 0; kw < kwSize; ++kw) { + // Read rhs slice of size {1, c, f} @ [kw, 0, 0]. + Value kwVal = builder.create(loc, kw); + VectorType rhsType = + VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType()); + Value rhs = builder.create( + loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}); + + for (int64_t w = 0; w < wSize; ++w) { + // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0].
+ Value lhsStridedIdx = builder.create( + loc, strideW * w + dilationW * kw); + VectorType lhsType = + VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType()); + Value lhs = builder.create( + loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}); + + // Read res slice: {n, 1, f} @ [0, w, 0]. + Value wVal = builder.create(loc, w); + VectorType resType = + VectorType::get({nSize, 1, fSize}, resShapedType.getElementType()); + // When operating on tensors, reading from the updated value is required + // for vector.transfer_read/write hoisting to function as expected. + Value res = builder.create( + loc, resType, resShaped, ValueRange{zero, wVal, zero}); + + // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f} + StringRef par = Par().strRef, red = Red().strRef; + AffineExpr n, one, f, c; + bindDims(ctx, n, one, f, c); + // clang-format off + res = builder.create( + loc, lhs, rhs, res, + /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}}, + /*iteratorTypes=*/ArrayRef{par, par, par, red}); + // clang-format on + + // Write back res slice: {n, 1, f} @ [0, w, 0]. + write = builder.create( + loc, res, resShaped, ValueRange{zero, wVal, zero}); + if (write.getNumResults() == 1) + resShaped = write->getResult(0); + } + } + + return write.getOperation(); + } + + /// Entry point that only handles the common form (no transposition yet): + /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} + FailureOr generateConv() { + AffineExpr n, w, f, kw, c; + bindDims(ctx, n, w, f, kw, c); + + if (!iters({Par(), Par(), Par(), Red(), Red()})) + return failure(); + + // No transposition needed.
+ if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, + /*rhsIndex*/ {kw, c, f}, + /*resIndex*/ {n, w, f}})) + return conv(); + return failure(); + } + +private: + bool valid; + int strideW, dilationW; + Value lhsShaped, rhsShaped, resShaped; + ShapedType lhsShapedType, rhsShapedType, resShapedType; +}; +} // namespace + +/// Helper function to vectorize a `linalgOp` with convolution semantics. +// TODO: extend the generic vectorization to support windows and drop this. +static FailureOr +vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) { + // TODO: these are legitimately part of ConvolutionOpInterface. + auto strides = convOp->getAttrOfType("strides"); + auto dilations = convOp->getAttrOfType("dilations"); + auto stride = strides ? *strides.getValues().begin() : 1; + auto dilation = dilations ? *dilations.getValues().begin() : 1; + LinalgOp linalgOp = cast(convOp.getOperation()); + Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation); + return e.generateConv(); +} + +struct VectorizeConvolution + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(ConvolutionOpInterface convOp, + PatternRewriter &rewriter) const override { + FailureOr resultOrFail = + vectorizeConvolution(rewriter, convOp); + if (failed(resultOrFail)) + return failure(); + Operation *newOp = *resultOrFail; + if (newOp->getNumResults() == 0) { + rewriter.eraseOp(convOp.getOperation()); + return success(); + } + assert(newOp->getNumResults() == 1 && "expected single result"); + rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0)); + return success(); + } +}; + +void mlir::linalg::populateConvolutionVectorizationPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- 
a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1257,36 +1257,34 @@ struct UnrolledOuterProductGenerator : public StructuredGenerator { - UnrolledOuterProductGenerator(PatternRewriter &rewriter, - vector::ContractionOp op) - : StructuredGenerator(rewriter, op), + UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) + : StructuredGenerator(builder, op), kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()), lhsType(op.getLhsType()) {} Value t(Value v) { static constexpr std::array perm = {1, 0}; - return rewriter.create(loc, v, perm); + return builder.create(loc, v, perm); } - LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) { + Value outer_prod(Value lhs, Value rhs, Value res, int reductionSize) { assert(reductionSize > 0); for (int64_t k = 0; k < reductionSize; ++k) { - Value a = rewriter.create(loc, lhs, k); - Value b = rewriter.create(loc, rhs, k); - res = rewriter.create(loc, res.getType(), a, b, - res, kind); + Value a = builder.create(loc, lhs, k); + Value b = builder.create(loc, rhs, k); + res = builder.create(loc, res.getType(), a, b, + res, kind); } - rewriter.replaceOp(op, res); - return success(); + return res; } /// Two outer parallel, one inner reduction (matmat flavor). - LogicalResult matmat() { + FailureOr matmat() { if (!iters({Par(), Par(), Red()})) return failure(); // Set up the parallel/reduction structure in the right form. AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); + bindDims(builder.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}})) return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); @@ -1318,11 +1316,11 @@ } /// One outer parallel, one inner reduction (matvec flavor) - LogicalResult matvec() { + FailureOr matvec() { if (!iters({Par(), Red()})) return failure(); AffineExpr m, k; - bindDims(rewriter.getContext(), m, k); + bindDims(builder.getContext(), m, k); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) @@ -1342,11 +1340,11 @@ // // One outer reduction, one inner parallel (tmatvec flavor) // - LogicalResult tmatvec() { + FailureOr tmatvec() { if (!iters({Red(), Par()})) return failure(); AffineExpr k, m; - bindDims(rewriter.getContext(), k, m); + bindDims(builder.getContext(), k, m); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) @@ -1399,12 +1397,21 @@ return failure(); UnrolledOuterProductGenerator e(rewriter, op); - if (succeeded(e.matmat())) + FailureOr matmatRes = e.matmat(); + if (succeeded(matmatRes)) { + rewriter.replaceOp(op, *matmatRes); return success(); - if (succeeded(e.matvec())) + } + FailureOr matvecRes = e.matvec(); + if (succeeded(matvecRes)) { + rewriter.replaceOp(op, *matvecRes); return success(); - if (succeeded(e.tmatvec())) + } + FailureOr tmatvecRes = e.tmatvec(); + if (succeeded(tmatvecRes)) { + rewriter.replaceOp(op, *tmatvecRes); return success(); + } return failure(); } diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -0,0 +1,108 @@ +// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s + +func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) { + linalg.conv_1d_nwc_wcf + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>) +
outs(%output : memref<4x2x8xf32>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_nwc_4x2x8_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// w == 0, kw == 0 +// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +/// w == 1, kw == 0 +// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into
vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] + +// ----- + +func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) { + linalg.conv_1d_nwc_wcf + {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>) + outs(%output : memref<4x2x8xf32>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_nwc_4x2x8_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// w == 0, kw == 0 +// CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT0_A:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +/// w == 1, kw == 0 +// CHECK:
%[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT1_A:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] + +/// w == 0, kw == 1 +// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT0_B:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +/// w == 1, kw == 1 +// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] +// CHECK: %[[CONTRACT1_B:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32>
into vector<4x1x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -553,6 +553,7 @@ LinalgTransformationFilter() .addOpFilter()); populatePadTensorOpVectorizationPatterns(patterns); + populateConvolutionVectorizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); }