diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
--- a/mlir/include/mlir/Dialect/LoopOps/Passes.h
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -23,6 +23,10 @@
 /// Creates a loop fusion pass which fuses parallel loops.
 std::unique_ptr<Pass> createParallelLoopFusionPass();
 
+/// Creates a pass that specializes parallel loops for unrolling and
+/// vectorization.
+std::unique_ptr<Pass> createParallelLoopSpecializationPass();
+
 /// Creates a pass which tiles innermost parallel loops.
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -109,6 +109,7 @@
 
   // LoopOps
   createParallelLoopFusionPass();
+  createParallelLoopSpecializationPass();
   createParallelLoopTilingPass();
 
   // QuantOps
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRLoopOpsTransforms
   ParallelLoopFusion.cpp
+  ParallelLoopSpecialization.cpp
   ParallelLoopTiling.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopSpecialization.cpp
@@ -0,0 +1,103 @@
+//===- ParallelLoopSpecialization.cpp - loop.parallel specialization ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Specializes parallel loops for easier unrolling and vectorization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Returns the smallest constant result of the affine.min that defines
+/// `bound`, or None if `bound` is not defined by an affine.min or that
+/// affine.min has no constant result.
+static Optional<int64_t> getMinConstantBound(Value bound) {
+  auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
+  if (!minOp)
+    return llvm::None;
+  // Track "no constant seen yet" explicitly instead of with an INT64_MAX
+  // sentinel, which a genuine constant result could collide with.
+  Optional<int64_t> minConstant;
+  for (AffineExpr expr : minOp.map().getResults()) {
+    auto constExpr = expr.dyn_cast<AffineConstantExpr>();
+    if (constExpr && (!minConstant || constExpr.getValue() < *minConstant))
+      minConstant = constExpr.getValue();
+  }
+  return minConstant;
+}
+
+/// Rewrites a loop whose upper bounds are all defined by affine.min ops with
+/// at least one constant result into
+///
+///   if (ub_0 == c_0 && ... && ub_n == c_n)
+///     <the loop with constant upper bounds c_0, ..., c_n>
+///   else
+///     <the original loop>
+///
+/// where c_i is the smallest constant result of the affine.min defining ub_i.
+/// This is beneficial if the loop will almost always run with the constant
+/// bounds, as that version can be fully unrolled and vectorized. Loops with
+/// any other kind of upper bound are left untouched.
+static void specializeLoopForUnrolling(ParallelOp op) {
+  SmallVector<int64_t, 2> constantIndices;
+  constantIndices.reserve(op.upperBound().size());
+  for (Value bound : op.upperBound()) {
+    Optional<int64_t> constantBound = getMinConstantBound(bound);
+    if (!constantBound)
+      return;
+    constantIndices.push_back(*constantBound);
+  }
+  // With no upper bounds there is nothing to specialize, and `cond` below
+  // would stay null.
+  if (constantIndices.empty())
+    return;
+
+  // Build the guard right before the loop and map every bound to its
+  // constant; the mapping also rewrites uses of the bounds inside the body.
+  OpBuilder b(op);
+  BlockAndValueMapping map;
+  Value cond;
+  for (auto bound : llvm::zip(op.upperBound(), constantIndices)) {
+    Value constant = b.create<ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
+    Value cmp = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
+                                 std::get<0>(bound), constant);
+    cond = cond ? b.create<AndOp>(op.getLoc(), cond, cmp) : cmp;
+    map.map(std::get<0>(bound), constant);
+  }
+  auto ifOp = b.create<loop::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
+  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
+  ifOp.getElseBodyBuilder().clone(*op.getOperation());
+  op.erase();
+}
+
+namespace {
+/// Applies specializeLoopForUnrolling to every loop.parallel in the
+/// function.
+struct ParallelLoopSpecialization
+    : public FunctionPass<ParallelLoopSpecialization> {
+  void runOnFunction() override {
+    getFunction().walk([](ParallelOp op) { specializeLoopForUnrolling(op); });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
+  return std::make_unique<ParallelLoopSpecialization>();
+}
+
+static PassRegistration<ParallelLoopSpecialization>
+    pass("parallel-loop-specialization",
+         "Specialize parallel loops for vectorization.");
diff --git a/mlir/test/Dialect/Loops/parallel-loop-specialization.mlir b/mlir/test/Dialect/Loops/parallel-loop-specialization.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-specialization.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -parallel-loop-specialization -split-input-file | FileCheck %s --dump-input-on-failure
+
+#map0 = affine_map<()[s0, s1] -> (1024, s0 - s1)>
+#map1 = affine_map<()[s0, s1] -> (64, s0 - s1)>
+
+func @parallel_loop(%outer_i0: index, %outer_i1: index, %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %A, 0 : memref<?x?xf32>
+  %d1 = dim %A, 1 : memref<?x?xf32>
+  %b0 = affine.min #map0()[%d0, %outer_i0]
+  %b1 = affine.min #map1()[%d1, %outer_i1]
+  loop.parallel (%i0, %i1) = (%c0, %c0) to (%b0, %b1) step (%c1, %c1) {
+    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func @parallel_loop(
+// CHECK-SAME:      [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: memref<?x?xf32>, [[VAL_3:%.*]]: memref<?x?xf32>, [[VAL_4:%.*]]: memref<?x?xf32>, [[VAL_5:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[VAL_6:%.*]] = constant 0 : index
+// CHECK:           [[VAL_7:%.*]] = constant 1 : index
+// CHECK:           [[VAL_8:%.*]] = dim [[VAL_2]], 0 : memref<?x?xf32>
+// CHECK:           [[VAL_9:%.*]] = dim [[VAL_2]], 1 : memref<?x?xf32>
+// CHECK:           [[VAL_10:%.*]] = affine.min #map0(){{\[}}[[VAL_8]], [[VAL_0]]]
+// CHECK:           [[VAL_11:%.*]] = affine.min #map1(){{\[}}[[VAL_9]], [[VAL_1]]]
+// CHECK:           [[VAL_12:%.*]] = constant 1024 : index
+// CHECK:           [[VAL_13:%.*]] = cmpi "eq", [[VAL_10]], [[VAL_12]] : index
+// CHECK:           [[VAL_14:%.*]] = constant 64 : index
+// CHECK:           [[VAL_15:%.*]] = cmpi "eq", [[VAL_11]], [[VAL_14]] : index
+// CHECK:           [[VAL_16:%.*]] = and [[VAL_13]], [[VAL_15]] : i1
+// CHECK:           loop.if [[VAL_16]] {
+// CHECK:             loop.parallel ([[VAL_17:%.*]], [[VAL_18:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_12]], [[VAL_14]]) step ([[VAL_7]], [[VAL_7]]) {
+// CHECK:               store
+// CHECK:             }
+// CHECK:           } else {
+// CHECK:             loop.parallel ([[VAL_22:%.*]], [[VAL_23:%.*]]) = ([[VAL_6]], [[VAL_6]]) to ([[VAL_10]], [[VAL_11]]) step ([[VAL_7]], [[VAL_7]]) {
+// CHECK:               store
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+#map2 = affine_map<()[s0, s1] -> (s0, s1)>
+
+func @parallel_loop_no_constant_bound(%ub0: index, %ub1: index, %A: memref<?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %b0 = affine.min #map2()[%ub0, %ub1]
+  loop.parallel (%i0) = (%c0) to (%b0) step (%c1) {
+    %elem = load %A[%i0] : memref<?xf32>
+    store %elem, %A[%i0] : memref<?xf32>
+  }
+  return
+}
+
+// An affine.min without a constant result leaves the loop unspecialized.
+// CHECK-LABEL:   func @parallel_loop_no_constant_bound(
+// CHECK-NOT:       loop.if
+// CHECK:           loop.parallel
+// CHECK-NOT:       loop.if
+// CHECK:           return