diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
--- a/mlir/include/mlir/Dialect/LoopOps/Passes.h
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -23,6 +23,10 @@
 /// Creates a loop fusion pass which fuses parallel loops.
 std::unique_ptr<Pass> createParallelLoopFusionPass();
 
+/// Creates a pass that specializes parallel loops for unrolling and
+/// vectorization.
+std::unique_ptr<Pass> createParallelLoopPreparationForVectorizationPass();
+
 /// Creates a pass which tiles innermost parallel loops.
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -109,6 +109,7 @@
 
   // LoopOps
   createParallelLoopFusionPass();
+  createParallelLoopPreparationForVectorizationPass();
   createParallelLoopTilingPass();
 
   // QuantOps
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRLoopOpsTransforms
   ParallelLoopFusion.cpp
+  ParallelLoopPreparationForVectorization.cpp
   ParallelLoopTiling.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopPreparationForVectorization.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopPreparationForVectorization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopPreparationForVectorization.cpp
@@ -0,0 +1,85 @@
+//===- ParallelLoopPreparationForVectorization.cpp - loop.parallel splitting =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Specializes parallel loops for easier unrolling and vectorization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Rewrite a loop with bounds defined by an affine.min with a constant into 2
+/// loops after checking if the bounds are equal to that constant. This is
+/// beneficial if the loop will almost always have the constant bound and that
+/// version can be fully unrolled and vectorized.
+///
+/// Every upper bound of `op` must be produced by an affine.min whose first
+/// operand is a constant index; otherwise `op` is left untouched. On success
+/// `op` is replaced by a loop.if that holds a clone with the constant bounds
+/// in its then-region and an unmodified clone in its else-region.
+static void specializeLoopForUnrolling(ParallelOp op) {
+  SmallVector<Value, 2> constantIndices;
+  constantIndices.reserve(op.upperBound().size());
+  for (auto bound : op.upperBound()) {
+    auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
+    // An affine.min over a map without inputs has no operands to inspect.
+    if (!minOp || minOp.getOperation()->getNumOperands() == 0)
+      return;
+    auto constantIndex =
+        dyn_cast_or_null<ConstantIndexOp>(minOp.getOperand(0).getDefiningOp());
+    if (!constantIndex)
+      return;
+    constantIndices.push_back(constantIndex);
+  }
+  // Without any bound there is no condition to branch on.
+  if (constantIndices.empty())
+    return;
+
+  OpBuilder b(op);
+  BlockAndValueMapping map;
+  Value cond;
+  for (auto bound : llvm::zip(op.upperBound(), constantIndices)) {
+    Value cmp = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
+                                 std::get<0>(bound), std::get<1>(bound));
+    cond = cond ? b.create<AndOp>(op.getLoc(), cond, cmp) : cmp;
+    // In the specialized clone, the dynamic bound is replaced by the constant
+    // it has just been compared against.
+    map.map(std::get<0>(bound), std::get<1>(bound));
+  }
+  auto ifOp = b.create<loop::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
+  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
+  ifOp.getElseBodyBuilder().clone(*op.getOperation());
+  op.erase();
+}
+
+namespace {
+/// Function pass that runs specializeLoopForUnrolling on every loop.parallel
+/// op nested in the function.
+struct ParallelLoopPreparationForVectorization
+    : public FunctionPass<ParallelLoopPreparationForVectorization> {
+  void runOnFunction() override {
+    getFunction().walk([](ParallelOp op) { specializeLoopForUnrolling(op); });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::createParallelLoopPreparationForVectorizationPass() {
+  return std::make_unique<ParallelLoopPreparationForVectorization>();
+}
+
+static PassRegistration<ParallelLoopPreparationForVectorization>
+    pass("parallel-loop-prep-vec",
+         "Specialize parallel loops for vectorization.");
diff --git a/mlir/test/Dialect/Loops/parallel-loop-preparation-for-vectorization.mlir b/mlir/test/Dialect/Loops/parallel-loop-preparation-for-vectorization.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-preparation-for-vectorization.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s -parallel-loop-prep-vec -split-input-file | FileCheck %s --dump-input-on-failure
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+
+func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c64 = constant 64 : index
+  %c1024 = constant 1024 : index
+  %d0 = dim %A, 0 : memref<?x?xf32>
+  %d1 = dim %A, 1 : memref<?x?xf32>
+  %b0 = affine.min #map0(%c1024, %d0, %outer_i0)
+  %b1 = affine.min #map0(%c64, %d1, %outer_i1)
+  loop.parallel (%i0, %i1) = (%c0, %c0) to (%b0, %b1) step (%c1, %c1) {
+    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func @parallel_loop(
+// CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: memref<?x?xf32>, [[VAL_3:%.*]]: memref<?x?xf32>, [[VAL_4:%.*]]: memref<?x?xf32>, [[VAL_5:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[VAL_6:%.*]] = constant 0 : index
+// CHECK:           [[VAL_7:%.*]] = constant 1 : index
+// CHECK:           [[VAL_8:%.*]] = constant 64 : index
+// CHECK:           [[VAL_9:%.*]] = constant 1024 : index
+// CHECK:           [[VAL_10:%.*]] = dim [[VAL_2]], 0 : memref<?x?xf32>
+// CHECK:           [[VAL_11:%.*]] = dim [[VAL_2]], 1 : memref<?x?xf32>
+// CHECK:           [[VAL_12:%.*]] = affine.min #map0([[VAL_9]], [[VAL_10]], [[VAL_0]])
+// CHECK:           [[VAL_13:%.*]] = affine.min #map0([[VAL_8]], [[VAL_11]], [[VAL_1]])
+// CHECK:           [[VAL_14:%.*]] = cmpi "eq", [[VAL_12]], [[VAL_9]] : index
+// CHECK:           [[VAL_15:%.*]] = cmpi "eq", [[VAL_13]], [[VAL_8]] : index
+// CHECK:           [[VAL_16:%.*]] = and [[VAL_14]], [[VAL_15]] : i1
+// CHECK:           loop.if [[VAL_16]] {
+// CHECK:             loop.parallel ([[VAL_17:%.*]], [[VAL_18:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_9]], [[VAL_8]]) step ([[VAL_7]], [[VAL_7]]) {
+// CHECK:               store
+// CHECK:             }
+// CHECK:           } else {
+// CHECK:             loop.parallel ([[VAL_22:%.*]], [[VAL_23:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_12]], [[VAL_13]]) step ([[VAL_7]], [[VAL_7]]) {
+// CHECK:               store
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+// CHECK:       }
+
+