diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -41,11 +41,15 @@ //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector; -/// Populates patterns for vectorization of all ConvN-D ops. +/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, ArrayRef tileSizes); +/// Populates patterns vectorizing static-shaped, non-depthwise conv ops. +void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate patterns that convert `ElementwiseMappable` ops to linalg /// parallel loops. void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -20,19 +20,25 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" + +#include #include +#define DEBUG_TYPE "linalg-vectorization" + using namespace mlir; using namespace mlir::linalg; @@ -1379,3 +1385,246 @@ return success(); } +
+//===----------------------------------------------------------------------===//
+// Convolution vectorization patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Vectorizes an op implementing `ConvolutionOpInterface` into a sequence of
+/// `vector.contract` ops.
+///
+/// Only a restricted form is handled: static shapes, all filter window
+/// dimensions of size one, all dilations equal to one, and no depthwise
+/// (filter-parallel) dimension. Other patterns can tile/pack more general
+/// convolutions down to this form.
+///
+/// The whole filter is read once. Every output window position is then
+/// unrolled: the matching input and output slices (size one along each window
+/// dimension) are read, contracted with the filter (the input channel is the
+/// only reduction iterator), and written back. On tensors the writes are
+/// chained and the last one replaces the op; on buffers the op is erased.
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(dbgs() << "try to vectorize conv op " << convOp << "\n");
+
+    // Don't support depthwise convolution for now.
+    if (convOp.getFilterParallelIteratorIndex().hasValue()) {
+      LLVM_DEBUG(dbgs() << "failed: depthwise conv op unsupported right now\n");
+      return failure();
+    }
+
+    Value input = convOp.image();
+    Value filter = convOp.filter();
+    Value output = convOp.output();
+
+    auto inputType = input.getType().cast<ShapedType>();
+    auto filterType = filter.getType().cast<ShapedType>();
+    auto outputType = output.getType().cast<ShapedType>();
+
+    // Make sure we have static shapes. Other patterns can tile and pack to
+    // expose static subcases if the original workload is dynamic.
+    if (!inputType.hasStaticShape() || !filterType.hasStaticShape() ||
+        !outputType.hasStaticShape()) {
+      LLVM_DEBUG(dbgs() << "failed: dynamic shape unsupported right now\n");
+      return failure();
+    }
+
+    auto inputChannelDim = convOp.getInputRecutionIteratorIndex();
+    auto outputWindowDims = convOp.getOutputWindowIteratorIndices();
+    auto filterWindowDims = convOp.getFilterWindowIteratorIndices();
+
+    LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+    SmallVector<int64_t, 4> loopUpperBounds =
+        linalgOp.computeStaticLoopSizes();
+
+    // The overall idea is to vectorize convolution ops of a specific form: the
+    // filter window dimensions should all be of size one. We also unroll all
+    // output window dimensions in this pattern. Together these guarantee
+    // size-one input window dimensions as well, which reduces the problem to
+    // a matmul-like contraction.
+
+    // Require filter window dimensions to be all size one. Other patterns can
+    // tile and materialize loops for filter window dimensions.
+    for (unsigned dim : filterWindowDims) {
+      if (loopUpperBounds[dim] != 1) {
+        LLVM_DEBUG(dbgs() << "failed: iterator#" << dim
+                          << " indexing into filter window must be 1\n");
+        return failure();
+      }
+    }
+
+    // Don't support dilation for now.
+    if (llvm::any_of(convOp.dilations().getValues<int64_t>(),
+                     [](int64_t dilation) { return dilation != 1; })) {
+      LLVM_DEBUG(dbgs() << "failed: requires dilation to be all 1 for now\n");
+      return failure();
+    }
+
+    // The code below pairs strides[k] with the k-th window dimension of the
+    // input shape, so there must be exactly one stride per window dimension.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    if (strides.size() != outputWindowDims.size()) {
+      LLVM_DEBUG(dbgs() << "failed: expected one stride per window dim\n");
+      return failure();
+    }
+
+    // Print out important configurations for debugging.
+    LLVM_DEBUG({
+      if (inputChannelDim)
+        dbgs() << "input channel: iterator#" << *inputChannelDim << " size "
+               << loopUpperBounds[*inputChannelDim] << "\n";
+
+      dbgs() << "filter window:\n";
+      for (unsigned dim : filterWindowDims) {
+        dbgs() << " iterator#" << dim << " size " << loopUpperBounds[dim]
+               << "\n";
+      }
+
+      dbgs() << "output window:\n";
+      for (unsigned dim : outputWindowDims) {
+        dbgs() << " iterator#" << dim << " size " << loopUpperBounds[dim]
+               << "\n";
+      }
+
+      dbgs() << "strides: [";
+      llvm::interleaveComma(strides, dbgs());
+      dbgs() << "]\n";
+    });
+
+    MLIRContext *context = convOp.getContext();
+    Location loc = convOp.getLoc();
+    Type elementType = outputType.getElementType();
+
+    // We required the filter window dimensions to be all size one above. Next
+    // we unroll the output window dimensions so the inner convolution becomes
+    // a contraction with permuted affine maps. Build the vector types, indexing
+    // maps, and iterators for that contraction first; no IR is created until
+    // every bail-out check has passed.
+    SmallVector<AffineMap, 4> indexingMaps = linalgOp.getIndexingMaps();
+    unsigned numLoops = linalgOp.getNumLoops();
+
+    // For input, every window dimension is indexed by a non-trivial expression
+    // combining an output window loop and a filter window loop. Since filter
+    // windows are size one and output windows are unrolled, each such
+    // dimension becomes a size-1 vector dimension indexed by the corresponding
+    // output window loop. Record (input shape dim, output window loop dim)
+    // pairs in input-shape order; the k-th pair uses strides[k].
+    auto inputMapResults = llvm::to_vector<6>(indexingMaps[0].getResults());
+    auto inputVectorShape = llvm::to_vector<6>(inputType.getShape());
+    SmallVector<std::pair<unsigned, unsigned>, 4> inputWindowDimToLoopDim;
+    for (unsigned i = 0, e = inputType.getRank(); i < e; ++i) {
+      if (inputMapResults[i].isa<AffineDimExpr>())
+        continue;
+      if (inputWindowDimToLoopDim.size() >= outputWindowDims.size()) {
+        LLVM_DEBUG(dbgs() << "failed: more input window dims than output "
+                             "window dims\n");
+        return failure();
+      }
+      unsigned loopDim = outputWindowDims[inputWindowDimToLoopDim.size()];
+      // Use output window dimensions to replace the ones from convolution op.
+      inputMapResults[i] = getAffineDimExpr(loopDim, context);
+      inputVectorShape[i] = 1;
+      inputWindowDimToLoopDim.emplace_back(i, loopDim);
+    }
+    if (inputWindowDimToLoopDim.size() != outputWindowDims.size()) {
+      LLVM_DEBUG(dbgs() << "failed: fewer input window dims than output "
+                           "window dims\n");
+      return failure();
+    }
+    auto inputVectorMap = AffineMap::get(numLoops, 0, inputMapResults, context);
+    auto inputVectorType = VectorType::get(inputVectorShape, elementType);
+    LLVM_DEBUG(dbgs() << "input vector type: " << inputVectorType << "\n");
+
+    auto filterVectorType = VectorType::get(filterType.getShape(), elementType);
+    LLVM_DEBUG(dbgs() << "filter vector type: " << filterVectorType << "\n");
+
+    // For output, each dimension indexed directly by an output window loop
+    // becomes a size-1 vector dimension. Record (output shape dim, loop dim)
+    // pairs so the unrolled loop below can compute the output offsets.
+    auto outputMapResults = indexingMaps[2].getResults();
+    auto outputVectorShape = llvm::to_vector<6>(outputType.getShape());
+    SmallVector<std::pair<unsigned, unsigned>, 4> outputWindowDimToLoopDim;
+    for (unsigned i = 0, e = outputType.getRank(); i < e; ++i) {
+      auto dimExpr = outputMapResults[i].dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        continue;
+      unsigned loopDim = dimExpr.getPosition();
+      if (!llvm::is_contained(outputWindowDims, loopDim))
+        continue;
+      outputVectorShape[i] = 1;
+      outputWindowDimToLoopDim.emplace_back(i, loopDim);
+    }
+    auto outputVectorType = VectorType::get(outputVectorShape, elementType);
+    LLVM_DEBUG(dbgs() << "output vector type: " << outputVectorType << "\n");
+    LLVM_DEBUG({
+      dbgs() << "output window shape dim to loop dim:\n";
+      for (const auto &pair : outputWindowDimToLoopDim)
+        dbgs() << " " << pair.first << " -> " << pair.second << "\n";
+    });
+
+    // For filter and output, the affine map is the same as the convolution op.
+    ArrayAttr indexingMapArray = rewriter.getAffineMapArrayAttr(
+        {inputVectorMap, indexingMaps[1], indexingMaps[2]});
+
+    // Build iterator types for the vector contraction op.
+    SmallVector<StringRef, 8> iteratorStrs(numLoops,
+                                           getParallelIteratorTypeName());
+    // Only the input channel dimension is reduction.
+    if (inputChannelDim)
+      iteratorStrs[*inputChannelDim] = getReductionIteratorTypeName();
+    ArrayAttr iterators = rewriter.getStrArrayAttr(iteratorStrs);
+
+    LLVM_DEBUG({
+      dbgs() << "vector contraction affine maps: [";
+      llvm::interleaveComma(indexingMapArray.getValue(), dbgs());
+      dbgs() << "]\n";
+
+      dbgs() << "vector contraction iterator types: [";
+      llvm::interleaveComma(iterators.getValue(), dbgs());
+      dbgs() << "]\n";
+    });
+
+    // All checks passed; start creating IR. Transfer read the entire filter
+    // at the beginning.
+    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
+    SmallVector<Value> filterIndices(filterType.getRank(), zero);
+    Value filterVector = rewriter.create<vector::TransferReadOp>(
+        loc, filterVectorType, filter, filterIndices);
+
+    // Loop over all output window positions. The number of window dimensions
+    // is only known at runtime, so instead of nested C++ loops we iterate over
+    // the total count and delinearize.
+    unsigned outputElementCount = 1;
+    for (unsigned dim : outputWindowDims)
+      outputElementCount *= loopUpperBounds[dim];
+
+    LLVM_DEBUG({
+      dbgs() << "output window element count: " << outputElementCount << "\n";
+    });
+
+    // Returns constant index SSA values for the given list of integers.
+    auto getIndexValues = [&](ArrayRef<int64_t> ints) {
+      SmallVector<Value> values;
+      values.reserve(ints.size());
+      for (int64_t i : ints)
+        values.push_back(rewriter.createOrFold<ConstantIndexOp>(loc, i));
+      return values;
+    };
+
+    // If this is working on tensors, we need to keep track of the update chain
+    // to outputs.
+    bool hasTensorSemantics = linalgOp.hasTensorSemantics();
+    Value outputSource = output;
+
+    // Current position of every loop; only output window loops are updated.
+    SmallVector<int64_t> loopIndices(numLoops, 0);
+    for (unsigned linearIndex = 0; linearIndex < outputElementCount;
+         ++linearIndex) {
+      // Delinearize into per-loop window indices; the last output window loop
+      // varies fastest.
+      int64_t remaining = linearIndex;
+      for (unsigned loopDim : llvm::reverse(outputWindowDims)) {
+        loopIndices[loopDim] = remaining % loopUpperBounds[loopDim];
+        remaining /= loopUpperBounds[loopDim];
+      }
+
+      SmallVector<int64_t> outputIndices(outputType.getRank(), 0);
+      for (const auto &pair : outputWindowDimToLoopDim)
+        outputIndices[pair.first] = loopIndices[pair.second];
+
+      // The input offset along a window dim is `outputIndex * stride`; the
+      // filter window contributes nothing because it has size one.
+      SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
+      for (unsigned k = 0, e = inputWindowDimToLoopDim.size(); k < e; ++k) {
+        unsigned shapeDim = inputWindowDimToLoopDim[k].first;
+        unsigned loopDim = inputWindowDimToLoopDim[k].second;
+        inputIndices[shapeDim] = loopIndices[loopDim] * strides[k];
+      }
+
+      Value inputVector = rewriter.create<vector::TransferReadOp>(
+          loc, inputVectorType, input, getIndexValues(inputIndices));
+
+      // Read in the initial value for this output vector. Reading from the
+      // original `output` is fine on tensors too: each iteration touches a
+      // disjoint output slice, so earlier writes never overlap this one.
+      SmallVector<Value> outputOffsets = getIndexValues(outputIndices);
+      Value outputVector = rewriter.create<vector::TransferReadOp>(
+          loc, outputVectorType, output, outputOffsets);
+      // Perform contraction.
+      outputVector = rewriter.create<vector::ContractionOp>(
+          loc, inputVector, filterVector, outputVector, indexingMapArray,
+          iterators);
+      // Write out the output vector.
+      auto writeOp = rewriter.create<vector::TransferWriteOp>(
+          loc, outputVector, outputSource, outputOffsets);
+      if (hasTensorSemantics)
+        outputSource = writeOp->getResult(0);
+    }
+
+    if (hasTensorSemantics)
+      rewriter.replaceOp(convOp, outputSource);
+    else
+      rewriter.eraseOp(convOp);
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateConvolutionVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s
+
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
+    outs(%output : 
memref<4x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = constant 0.000000e+00 : f32
+
+// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true]} : memref<1x3x8xf32>, vector<1x3x8xf32>
+
+// CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true]} : memref<4x2x8xf32>, vector<4x1x8xf32>
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]] : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<4x1x8xf32>, memref<4x2x8xf32>
+
+// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x1x3xf32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true]} : memref<4x2x8xf32>, vector<4x1x8xf32>
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]] : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] {in_bounds = [true, true, true]} : vector<4x1x8xf32>, memref<4x2x8xf32>
+
+// -----
+
+func @conv2d_nhwc_4x2x2x8_memref(%input: memref<4x3x6x3xf32>, %filter: memref<1x1x3x8xf32>, %output: memref<4x2x2x8xf32>) {
+  linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<[2,3]> : tensor<2xi64>}
+    ins(%input, %filter : memref<4x3x6x3xf32>, memref<1x1x3x8xf32>)
+    outs(%output : memref<4x2x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
+// CHECK: func @conv2d_nhwc_4x2x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6x3xf32>, %[[FILTER:.+]]: memref<1x1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x2x8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = constant 0.000000e+00 : f32
+
+// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<1x1x3x8xf32>, vector<1x1x3x8xf32>
+
+// CHECK: %[[V_INPUT_00:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x3x6x3xf32>, vector<4x1x1x3xf32>
+// CHECK: %[[V_OUTPUT_00:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x2x2x8xf32>, vector<4x1x1x8xf32>
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_00]], %[[V_FILTER]], %[[V_OUTPUT_00]] : vector<4x1x1x3xf32>, vector<1x1x3x8xf32> into vector<4x1x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x1x8xf32>, memref<4x2x2x8xf32>
+
+// CHECK: %[[V_INPUT_03:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C3]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x3x6x3xf32>, vector<4x1x1x3xf32>
+// CHECK: %[[V_OUTPUT_01:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x2x2x8xf32>, vector<4x1x1x8xf32>
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_03]], %[[V_FILTER]], %[[V_OUTPUT_01]] : vector<4x1x1x3xf32>, vector<1x1x3x8xf32> into vector<4x1x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x1x8xf32>, memref<4x2x2x8xf32>
+
+// CHECK: %[[V_INPUT_20:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x3x6x3xf32>, vector<4x1x1x3xf32>
+// CHECK: %[[V_OUTPUT_10:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x2x2x8xf32>, vector<4x1x1x8xf32>
+// CHECK: %[[CONTRACT2:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_20]], %[[V_FILTER]], %[[V_OUTPUT_10]] : vector<4x1x1x3xf32>, vector<1x1x3x8xf32> into vector<4x1x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT2]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x1x8xf32>, memref<4x2x2x8xf32>
+
+// CHECK: %[[V_INPUT_23:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C3]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x3x6x3xf32>, vector<4x1x1x3xf32>
+// CHECK: %[[V_OUTPUT_11:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C1]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x2x2x8xf32>, vector<4x1x1x8xf32>
+// CHECK: %[[CONTRACT3:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_23]], %[[V_FILTER]], %[[V_OUTPUT_11]] : vector<4x1x1x3xf32>, vector<1x1x3x8xf32> into vector<4x1x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT3]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C1]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x1x8xf32>, memref<4x2x2x8xf32>
+
+// -----
+
+func @conv2d_nchw_4x8x2x2_tensor(%input: tensor<4x3x3x6xf32>, %filter: tensor<8x3x1x1xf32>, %init: tensor<4x8x2x2xf32>) -> tensor<4x8x2x2xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<[2,3]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x3x3x6xf32>, tensor<8x3x1x1xf32>)
+    outs(%init : tensor<4x8x2x2xf32>) -> tensor<4x8x2x2xf32>
+  return %0: tensor<4x8x2x2xf32>
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
+// CHECK: func @conv2d_nchw_4x8x2x2_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<4x3x3x6xf32>, %[[FILTER:.+]]: tensor<8x3x1x1xf32>, %[[OUTPUT:.+]]: tensor<4x8x2x2xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = constant 0.000000e+00 : f32
+
+// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<8x3x1x1xf32>, vector<8x3x1x1xf32>
+
+// CHECK: %[[V_INPUT_00:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x3x3x6xf32>, vector<4x3x1x1xf32>
+// CHECK: %[[V_OUTPUT_00:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x8x2x2xf32>, vector<4x8x1x1xf32>
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel", "parallel"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_00]], %[[V_FILTER]], %[[V_OUTPUT_00]] : vector<4x3x1x1xf32>, vector<8x3x1x1xf32> into vector<4x8x1x1xf32>
+// CHECK: %[[W_OUTPUT_00:.+]] = vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x8x1x1xf32>, tensor<4x8x2x2xf32>
+
+// CHECK: %[[V_INPUT_03:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C3]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x3x3x6xf32>, vector<4x3x1x1xf32>
+// CHECK: %[[V_OUTPUT_01:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C1]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x8x2x2xf32>, vector<4x8x1x1xf32>
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel", "parallel"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_03]], %[[V_FILTER]], %[[V_OUTPUT_01]] : vector<4x3x1x1xf32>, vector<8x3x1x1xf32> into vector<4x8x1x1xf32>
+// CHECK: %[[W_OUTPUT_01:.+]] = vector.transfer_write %[[CONTRACT1]], %[[W_OUTPUT_00]][%[[C0]], %[[C0]], %[[C0]], %[[C1]]] {in_bounds = [true, true, true, true]} : vector<4x8x1x1xf32>, tensor<4x8x2x2xf32>
+
+// CHECK: %[[V_INPUT_20:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C2]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x3x3x6xf32>, vector<4x3x1x1xf32>
+// CHECK: %[[V_OUTPUT_10:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x8x2x2xf32>, vector<4x8x1x1xf32>
+// CHECK: %[[CONTRACT2:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel", "parallel"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_20]], %[[V_FILTER]], %[[V_OUTPUT_10]] : vector<4x3x1x1xf32>, vector<8x3x1x1xf32> into vector<4x8x1x1xf32>
+// CHECK: %[[W_OUTPUT_10:.+]] = vector.transfer_write %[[CONTRACT2]], %[[W_OUTPUT_01]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x8x1x1xf32>, tensor<4x8x2x2xf32>
+
+// CHECK: %[[V_INPUT_23:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C2]], %[[C3]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x3x3x6xf32>, vector<4x3x1x1xf32>
+// CHECK: %[[V_OUTPUT_11:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C1]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x8x2x2xf32>, vector<4x8x1x1xf32>
+// CHECK: %[[CONTRACT3:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel", "parallel"],
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: } %[[V_INPUT_23]], %[[V_FILTER]], %[[V_OUTPUT_11]] : vector<4x3x1x1xf32>, vector<8x3x1x1xf32> into vector<4x8x1x1xf32>
+// CHECK: %[[W_OUTPUT_11:.+]] = vector.transfer_write %[[CONTRACT3]], %[[W_OUTPUT_10]][%[[C0]], %[[C0]], %[[C1]], %[[C1]]] {in_bounds = [true, true, true, true]} : vector<4x8x1x1xf32>, tensor<4x8x2x2xf32>
+
+// CHECK: return %[[W_OUTPUT_11]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -548,6 +548,7 @@ LinalgTransformationFilter() .addOpFilter()); populatePadTensorOpVectorizationPatterns(patterns); + populateConvolutionVectorizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); }