diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -30,6 +30,10 @@ SmallVector loops; }; +/// Populates patterns for vectorization of all ConvN-D ops. +void populateConvVectorizationPatterns(MLIRContext *context, + OwningRewritePatternList &patterns); + /// Performs standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` /// The permutation is expressed as a list of integers that specify @@ -531,6 +535,58 @@ PatternRewriter &rewriter) const override; }; +/// Converts Convolution op into vector contraction. +/// +/// Conversion expects every ConvOp dimension marked as false in the *mask* +/// to have size 1. This ensures that the ConvOp can be lowered to a vector +/// contraction over the dimensions marked as true in the *mask*. +/// +/// A good example is ConvNHWCOp which is 2D Conv op with channels as the last +/// dimension. For this op we contract last 3 dimensions. +/// The initial op definition looks like this: +/// ``` +/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : +/// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) +/// ``` +/// This op can be expressed as a dot product between %arg0 (input) and +/// %arg1 (kernel) which is written into first entry of %arg2 (output). This is +/// the ConvOp this pass expects and converts into: +/// ``` +/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +/// #map1 = affine_map<(d0, d1, d2) -> ()> +/// ..... 
+/// %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32 +/// : memref<1x3x3x3xf32>, vector<3x3x3xf32> +/// %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32 +/// : memref<1x3x3x3xf32>, vector<3x3x3xf32> +/// %2 = vector.contract {indexing_maps = [#map0, #map0, #map1], +/// iterator_types = ["reduction", "reduction", "reduction"]} %0, %1, +/// %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 +/// store %2, %arg2[%c0, %c0, %c0, %c0] : memref +/// ``` +/// where first 2 operations read input and kernel memory buffers into vectors. +/// Subsequently, they are contracted together and the result is written to +/// the first entry of the output buffer. +template +struct ConvOpVectorization : public OpRewritePattern { + // NOTE(review): the inherited constructors leave `mask` empty; confirm that + // patterns are only ever built through the constructor below. + using OpRewritePattern::OpRewritePattern; + /// One flag per op dimension: true marks a vectorized dimension, false a + /// dimension that must be of size 1. + SmallVector mask; + + /// Builds the pattern from `msk`, which must hold exactly N flags. + ConvOpVectorization(MLIRContext *context, SmallVector msk) + : OpRewritePattern(context) { + assert(msk.size() == N && "Mask size does not match rank"); + this->mask = msk; + } + + LogicalResult matchAndRewrite(ConvOp op, + PatternRewriter &rewriter) const override; +}; + //===----------------------------------------------------------------------===// // Support for staged pattern application. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -367,3 +367,104 @@ return success(); } + +/// Rewrites `op` into two vector transfer reads of the input and kernel, a +/// contraction of both vectors and a store of the result to the output buffer. +/// Fails when either operand shape is dynamic or does not fit the mask. +template +LogicalResult ConvOpVectorization::matchAndRewrite( + ConvOp op, PatternRewriter &rewriter) const { + // Only dimensions of exactly this size are vectorized. + const unsigned dimSize = 3; + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + edsc::ScopedContext scope(rewriter, loc); + + ShapedType inShapeType = op.getInputShapedType(0); + ShapedType kShapeType = op.getInputShapedType(1); + + ArrayRef inShape = inShapeType.getShape(); + ArrayRef kShape = kShapeType.getShape(); + + if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape()) + return failure(); + + SmallVector mapping; + // Fail to apply when the size of a non-vectorized dimension is not 1 or + // when the size of a vectorized dimension is not dimSize. 
+ for (unsigned i = 0; i < N; i++) { + if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) + return failure(); + if (mask[i] && (inShape[i] != dimSize || kShape[i] != dimSize)) + return failure(); + + if (mask[i]) + mapping.push_back(getAffineDimExpr(i, context)); + } + + Value input = op.getInput(0); + Value kernel = op.getInput(1); + Value output = op.getOutputBuffer(0); + + unsigned rank = inShapeType.getRank(); + unsigned numDims = mapping.size(); + Type elemType = inShapeType.getElementType(); + + // Map from all `rank` operand dimensions to the vectorized ones only. + auto map = AffineMap::get(rank, 0, mapping, context); + SmallVector zeros(rank, std_constant_index(0)); + auto vecType = + VectorType::get(SmallVector(numDims, dimSize), elemType); + + auto inputVec = vector_transfer_read(vecType, input, zeros, map); + auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); + + // The contraction accumulates into a zero constant of the element type. + auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType)); + + std::array indexingMaps{ + AffineMap::getMultiDimIdentityMap(numDims, context), + AffineMap::getMultiDimIdentityMap(numDims, context), + AffineMap::get(numDims, 0, {}, context)}; + + std::vector iteratorTypes(numDims, "reduction"); + + auto result = rewriter.create( + loc, inputVec, kernelVec, acc, + rewriter.getAffineMapArrayAttr(indexingMaps), + rewriter.getStrArrayAttr(iteratorTypes)); + + rewriter.create(loc, result, output, ValueRange(zeros)); + rewriter.eraseOp(op); + return success(); +} + +void mlir::linalg::populateConvVectorizationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert>( + context, SmallVector{true}); + + patterns.insert>( + context, SmallVector{false, true, true}); + + patterns.insert>( + context, SmallVector{false, true, true}); + + patterns.insert>( + context, SmallVector{true, true}); + + patterns.insert>( + context, SmallVector{false, true, true, true}); + + patterns.insert>( + context, SmallVector{false, true, true, true}); + + patterns.insert>( + context, SmallVector{true, true, true}); + + patterns.insert>( + context, 
SmallVector{false, true, true, true, true}); + + patterns.insert>( + context, SmallVector{false, true, true, true, true}); +} diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir @@ -0,0 +1,167 @@ +// RUN: mlir-opt %s -test-conv-vectorization --cse | FileCheck %s + +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0) -> ()> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[$map5:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-DAG: #[[$map6:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$map7:.*]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[$map8:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)> +// CHECK-DAG: #[[$map9:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$map10:.*]] = affine_map<(d0, d1, d2, d3) -> ()> + +func @conv_1d(%arg0: memref<3xf32>, %arg1: memref<3xf32>, %arg2: memref) { + linalg.conv_1d %arg0, %arg1, %arg2 : (memref<3xf32>, memref<3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_1d +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]]], %[[cst]] : memref<3xf32>, vector<3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], iterator_types = ["reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3xf32>, vector<3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]]] : memref +// CHECK: return + +func @conv_1d_ncw(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, 
%arg2: memref) { + linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_1d_ncw +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + + +func @conv_1d_nwc(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref) { + linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_1d_nwc +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_2d(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref) { + linalg.conv_2d %arg0, %arg1, %arg2 : (memref<3x3xf32>, memref<3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_2d +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> +// CHECK: %[[v1:.*]] 
= vector.transfer_read %[[arg1]][%[[c0]], %[[c0]]], %[[cst]] : memref<3x3xf32>, vector<3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_2d_nchw(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) { + linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_2d_nchw +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_2d_nhwc(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) { + linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_2d_nhwc +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} 
%[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_3d(%arg0: memref<3x3x3xf32>, %arg1: memref<3x3x3xf32>, %arg2: memref) { + linalg.conv_3d %arg0, %arg1, %arg2 : (memref<3x3x3xf32>, memref<3x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_3d +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<3x3x3xf32>, vector<3x3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_3d_ncdhw(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) { + linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_3d_ncdhw +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return + +func @conv_3d_ndhwc(%arg0: 
memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) { + linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref) + return +} + +// CHECK-LABEL: @conv_3d_ndhwc +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> +// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3x3xf32> +// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32> +// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32 +// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref +// CHECK: return diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ TestExpandTanh.cpp TestCallGraph.cpp TestConstantFold.cpp + TestConvVectorization.cpp TestConvertCallOp.cpp TestConvertGPUKernelToCubin.cpp TestConvertGPUKernelToHsaco.cpp diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -0,0 +1,51 @@ +//===- TestConvVectorization.cpp - Linalg to Vector dialect conversion ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +/// Test pass applying the populateConvVectorizationPatterns patterns to a module via partial conversion. +class TestConvVectorization + : public PassWrapper> { + void runOnOperation() override; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } +}; +} // namespace + +void TestConvVectorization::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + OwningRewritePatternList patterns; + linalg::populateConvVectorizationPatterns(context, patterns); + + if (failed(applyPartialConversion(module, target, patterns))) + return signalPassFailure(); +} + +namespace mlir { +void registerTestConvVectorization() { + PassRegistration testTransformPatternsPass( + "test-conv-vectorization", "Test vectorization of convolutions"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -45,6 +45,7 @@ void registerTestBufferPlacementPreparationPass(); void registerTestCallGraphPass(); void registerTestConstantFold(); +void registerTestConvVectorization(); void registerTestConvertGPUKernelToCubinPass(); void registerTestConvertGPUKernelToHsacoPass(); void registerTestDominancePass(); @@ -93,6 +94,7 @@ registerTestAffineLoopUnswitchingPass(); registerTestLoopPermutationPass(); registerTestCallGraphPass(); + registerTestConvVectorization(); registerTestConstantFold(); #if MLIR_CUDA_CONVERSIONS_ENABLED registerTestConvertGPUKernelToCubinPass();