diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -46,11 +46,6 @@ //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector; -/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops. -void populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns, - ArrayRef tileSizes); - /// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops /// were decomposed previously. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -43,8 +43,9 @@ #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) -static FailureOr -vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp); +/// Try to vectorize `convOp` as a convolution. +static FailureOr vectorizeConvolution(OpBuilder &b, + LinalgOp convOp); /// Return the unique instance of OpType in `block` if it is indeed unique. /// Return null if none or more than 1 instances exist. @@ -636,13 +637,12 @@ SmallVector results; // TODO: isaConvolutionOpInterface that can also infer from generic // features. Will require stride/dilation attributes inference. - if (auto convOp = dyn_cast(linalgOp.getOperation())) { - LDBG("Vectorize as a conv: " << linalgOp); - FailureOr convOr = vectorizeConvolution(rewriter, convOp); - if (failed(convOr)) - return failure(); + FailureOr convOr = vectorizeConvolution(rewriter, linalgOp); + if (succeeded(convOr)) { llvm::append_range(results, (*convOr)->getResults()); } else { + if (failed(vectorizeLinalgOpPrecondition(linalgOp))) + return failure(); LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp); if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) return failure(); @@ -1098,134 +1098,6 @@ patterns.getContext(), baseBenefit.getBenefit() + 1); } -// TODO: cleanup all the convolution vectorization patterns. -template -LogicalResult ConvOpVectorization::matchAndRewrite( - ConvOp op, PatternRewriter &rewriter) const { - Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); - - OpOperand *input = op.getInputOperand(0); - OpOperand *kernel = op.getInputOperand(1); - OpOperand *output = op.getOutputOperand(0); - ArrayRef inShape = op.getShape(input); - ArrayRef kShape = op.getShape(kernel); - - if (llvm::any_of(inShape, ShapedType::isDynamic) || - llvm::any_of(kShape, ShapedType::isDynamic)) - return failure(); - - SmallVector mapping; - SmallVector vectorDims; - // Fail to apply when the size of not vectorized dimension is not 1. - for (unsigned i = 0; i < N; i++) { - if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) - return failure(); - - if (mask[i] && inShape[i] != kShape[i]) - return failure(); - - if (mask[i]) { - mapping.push_back(getAffineDimExpr(i, context)); - vectorDims.push_back(inShape[i]); - } - } - - int64_t rank = op.getRank(input); - int64_t numDims = mapping.size(); - Type elemType = getElementTypeOrSelf(input->get()); - - auto map = AffineMap::get(rank, 0, mapping, context); - SmallVector zeros(rank, - rewriter.create(loc, 0)); - auto vecType = VectorType::get(vectorDims, elemType); - - auto inputVec = rewriter.create( - loc, vecType, input->get(), zeros, map); - auto kernelVec = rewriter.create( - loc, vecType, kernel->get(), zeros, map); - - auto acc = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); - - std::array indexingMaps{ - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::get(numDims, 0, {}, context)}; - - std::vector iteratorTypes(numDims, "reduction"); - - auto result = rewriter.create( - loc, inputVec, kernelVec, acc, - rewriter.getAffineMapArrayAttr(indexingMaps), - rewriter.getStrArrayAttr(iteratorTypes)); - - rewriter.create(loc, result, output->get(), - ValueRange(zeros)); - rewriter.eraseOp(op); - return success(); -} - -/// Inserts tiling, promotion and vectorization pattern for ConvOp -/// conversion into corresponding pattern lists. -template -static void populateVectorizationPatterns( - RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns, - RewritePatternSet &vectorizationPatterns, ArrayRef tileSizes) { - auto *context = tilingPatterns.getContext(); - if (tileSizes.size() < N) - return; - - constexpr static StringRef kTiledMarker = "TILED"; - constexpr static StringRef kPromotedMarker = "PROMOTED"; - tilingPatterns.add( - ConvOp::getOperationName(), context, - LinalgTilingOptions().setTileSizes(tileSizes), - LinalgTransformationFilter(ArrayRef{}, - StringAttr::get(context, kTiledMarker))); - - promotionPatterns.add>( - context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(context, kTiledMarker), - StringAttr::get(context, kPromotedMarker))); - - SmallVector mask(N); - int offset = tileSizes.size() - N; - std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), - [](int64_t i) -> bool { return i > 1; }); - - vectorizationPatterns.add>(context, mask); -} - -void mlir::linalg::populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns, - ArrayRef tileSizes) { - RewritePatternSet tiling(context); - RewritePatternSet promotion(context); - RewritePatternSet vectorization(context); - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, - vectorization, tileSizes); - - populateVectorizationPatterns(tiling, promotion, - vectorization, tileSizes); - - populateVectorizationPatterns( - tiling, promotion, vectorization, tileSizes); - - patterns.push_back(std::move(tiling)); - patterns.push_back(std::move(promotion)); - patterns.push_back(std::move(vectorization)); -} - //----------------------------------------------------------------------------// // Forwarding patterns //----------------------------------------------------------------------------// @@ -1754,40 +1626,39 @@ }; } // namespace -/// Helper function to vectorize a `linalgOp` with convolution semantics. +/// Helper function to vectorize a LinalgOp with convolution semantics. // TODO: extend the generic vectorization to support windows and drop this. -static FailureOr -vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) { - // TODO: these are legitimately part of ConvolutionOpInterface. - auto strides = convOp->getAttrOfType("strides"); - auto dilations = convOp->getAttrOfType("dilations"); +static FailureOr vectorizeConvolution(OpBuilder &b, LinalgOp op) { + // The ConvolutionOpInterface gives us guarantees of existence for + // strides/dilations. However, we do not need to rely on those, we can simply + // use them if present, otherwise use the default and let the generic conv. + // matcher in the ConvGenerator succeed or fail. + auto strides = op->getAttrOfType("strides"); + auto dilations = op->getAttrOfType("dilations"); auto stride = strides ? *strides.getValues().begin() : 1; auto dilation = dilations ? *dilations.getValues().begin() : 1; - LinalgOp linalgOp = cast(convOp.getOperation()); - Conv1DNwcGenerator e(b, linalgOp, stride, dilation); + Conv1DNwcGenerator e(b, op, stride, dilation); auto res = e.generateConv(); if (succeeded(res)) return res; return e.generateDilatedConv(); } -struct VectorizeConvolution - : public OpInterfaceRewritePattern { +struct VectorizeConvolution : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(ConvolutionOpInterface convOp, + LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - FailureOr resultOrFail = - vectorizeConvolution(rewriter, convOp); + FailureOr resultOrFail = vectorizeConvolution(rewriter, op); if (failed(resultOrFail)) return failure(); Operation *newOp = *resultOrFail; if (newOp->getNumResults() == 0) { - rewriter.eraseOp(convOp.getOperation()); + rewriter.eraseOp(op.getOperation()); return success(); } assert(newOp->getNumResults() == 1 && "expected single result"); - rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0)); + rewriter.replaceOp(op.getOperation(), newOp->getResult(0)); return success(); } }; diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir deleted file mode 100644 --- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file -// | FileCheck %s - -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)> -// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> - -// CHECK-LABEL: @conv_1d -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, %arg1: memref, %arg2: memref) { -// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[v0:.*]] = memref.dim %[[arg1]], %[[c0]] : memref -// CHECK: %[[v1:.*]] = memref.dim %[[arg2]], %[[c0]] : memref -// CHECK: %[[v2:.*]] = memref.dim %[[arg0]], %[[c0]] : memref -// CHECK: %[[v3:.*]] = memref.alloc(%[[c12]]) : memref -// CHECK: %[[v4:.*]] = memref.alloc(%[[c12]]) : memref -// CHECK: %[[v5:.*]] = memref.alloc(%[[c4]]) : memref -// CHECK: %[[v6:.*]] = memref.view %[[v3]][%[[c0]]][] : memref to memref<3xf32> -// CHECK: %[[v7:.*]] = memref.view %[[v4]][%[[c0]]][] : memref to memref<3xf32> -// CHECK: %[[v8:.*]] = memref.view %[[v5]][%[[c0]]][] : memref to memref<1xf32> -// CHECK: scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] { -// CHECK: %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]] -// CHECK: %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1] : memref to memref -// CHECK: %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1] : memref<1xf32> to memref -// CHECK: scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] { -// CHECK: %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]]) -// CHECK: %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]] -// CHECK: %[[v14:.*]] = subview %arg0[%12] [%13] [1] : memref to memref -// CHECK: %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0] -// CHECK: %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1] : memref to memref -// CHECK: %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1] : memref<3xf32> to memref -// CHECK: %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32> -// CHECK: %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32> -// CHECK: %[[v21:.*]] = arith.mulf %[[v19]], %[[v20]] : vector<3xf32> -// CHECK: %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32 -// CHECK: store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32> -// CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] { -// CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref -// CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref - linalg.conv_1d ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRLinalgTestPasses TestComprehensiveBufferize.cpp - TestConvVectorization.cpp TestLinalgCodegenStrategy.cpp TestLinalgDistribution.cpp TestLinalgElementwiseFusion.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp deleted file mode 100644 --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ /dev/null @@ -1,143 +0,0 @@ -//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms.h" -#include "mlir/Dialect/Vector/VectorTransforms.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopUtils.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; -using namespace vector; - -namespace { -/// A pass converting MLIR Linalg ops into Vector ops. -class TestConvVectorization - : public PassWrapper> { -public: - StringRef getArgument() const final { return "test-conv-vectorization"; } - StringRef getDescription() const final { - return "Test vectorization of convolutions"; - } - TestConvVectorization() = default; - TestConvVectorization(const TestConvVectorization &) {} - explicit TestConvVectorization(ArrayRef tileSizesParam) { - tileSizes = tileSizesParam; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - } - - ListOption tileSizes{ - *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; -}; -} // namespace - -void TestConvVectorization::runOnOperation() { - MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); - - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - - SmallVector stage1Patterns; - linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); - SmallVector frozenStage1Patterns; - llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); - - RewritePatternSet stage2Patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); - - auto stage3Transforms = [](Operation *op) { - PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(cast(op)))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); - op->walk([](FuncOp func) { - promoteSingleIterationLoops(func); - linalg::hoistRedundantVectorTransfers(func); - }); - return success(); - }; - - (void)linalg::applyStagedPatterns(module, frozenStage1Patterns, - std::move(stage2Patterns), - stage3Transforms); - - //===--------------------------------------------------------------------===// - // Post staged patterns transforms - //===--------------------------------------------------------------------===// - - VectorTransformsOptions vectorTransformOptions{ - VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel, - VectorTransposeLowering::EltWise}; - - RewritePatternSet vectorTransferPatterns(context); - // Pattern is not applied: rank-reducing vector transfer is not yet supported - // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp). - vectorTransferPatterns.add( - context, vectorTransformOptions); - (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns)); - - // Programmatic controlled lowering of linalg.copy and linalg.fill. - PassManager pm(context); - pm.addNestedPass(createConvertLinalgToLoopsPass()); - if (failed(pm.run(module))) - llvm_unreachable("Unexpected failure in linalg to loops pass."); - - // Programmatic controlled lowering of vector.contract only. - RewritePatternSet vectorContractLoweringPatterns(context); - populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns); - populateVectorContractLoweringPatterns(vectorContractLoweringPatterns, - vectorTransformOptions); - populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns); - populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns); - populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns, - vectorTransformOptions); - (void)applyPatternsAndFoldGreedily(module, - std::move(vectorContractLoweringPatterns)); - - // Programmatic controlled lowering of vector.transfer only. - RewritePatternSet vectorToLoopsPatterns(context); - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, - VectorTransferToSCFOptions()); - (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); - - // Ensure we drop the marker in the end. - module.walk([](linalg::LinalgOp op) { - op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); - }); -} - -namespace mlir { -namespace test { -void registerTestConvVectorization() { - PassRegistration(); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -66,7 +66,6 @@ void registerTestCallGraphPass(); void registerTestComprehensiveFunctionBufferize(); void registerTestConstantFold(); -void registerTestConvVectorization(); void registerTestGpuSerializeToCubinPass(); void registerTestGpuSerializeToHsacoPass(); void registerTestDataLayoutQuery(); @@ -162,7 +161,6 @@ mlir::test::registerTestGpuSerializeToHsacoPass(); #endif mlir::test::registerTestComprehensiveFunctionBufferize(); - mlir::test::registerTestConvVectorization(); mlir::test::registerTestDecomposeCallGraphTypes(); mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDominancePass();