diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -12,10 +12,14 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class DataFlowSolver;
+
 namespace arith {
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+#define GEN_PASS_DECL_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
 
@@ -44,6 +48,14 @@
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 
+/// Add patterns for int range based optimizations. The patterns keep a
+/// reference to `solver` and query it for ranges, so it must outlive them.
+void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
+                                           DataFlowSolver &solver);
+
+/// Create a pass which does optimizations based on integer range analysis.
+std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,6 +49,15 @@
   let constructor = "mlir::arith::createArithUnsignedWhenEquivalentPass()";
 }
 
+def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
+  let summary = "Do optimizations based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and applies optimizations based on its
+    results, e.g. replacing `arith.cmpi` with a constant when its result can be
+    inferred from the ranges of its arguments.
+  }];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   Bufferize.cpp
   EmulateWideInt.cpp
   ExpandOps.cpp
+  IntRangeOptimizations.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -0,0 +1,209 @@
+//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <array>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::dataflow;
+
+/// Returns true if the two integer ranges may have a value in common. The
+/// ranges are reported as disjoint only when both their signed and their
+/// unsigned bounds are disjoint.
+static bool intersects(const ConstantIntRanges &lhs,
+                       const ConstantIntRanges &rhs) {
+  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
+           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+}
+
+// The `handle*` helpers below each try to evaluate one `arith.cmpi` predicate
+// from the operand ranges alone. They return the statically known result, or
+// failure() when the ranges do not decide the predicate.
+
+static FailureOr<bool> handleEq(const ConstantIntRanges &lhs,
+                                const ConstantIntRanges &rhs) {
+  if (!intersects(lhs, rhs))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleNe(const ConstantIntRanges &lhs,
+                                const ConstantIntRanges &rhs) {
+  if (!intersects(lhs, rhs))
+    return true;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSlt(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  if (lhs.smax().slt(rhs.smin()))
+    return true;
+
+  if (lhs.smin().sge(rhs.smax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSle(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  if (lhs.smax().sle(rhs.smin()))
+    return true;
+
+  if (lhs.smin().sgt(rhs.smax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSgt(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  return handleSlt(rhs, lhs);
+}
+
+static FailureOr<bool> handleSge(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  return handleSle(rhs, lhs);
+}
+
+static FailureOr<bool> handleUlt(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  if (lhs.umax().ult(rhs.umin()))
+    return true;
+
+  if (lhs.umin().uge(rhs.umax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleUle(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  if (lhs.umax().ule(rhs.umin()))
+    return true;
+
+  if (lhs.umin().ugt(rhs.umax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleUgt(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  return handleUlt(rhs, lhs);
+}
+
+static FailureOr<bool> handleUge(const ConstantIntRanges &lhs,
+                                 const ConstantIntRanges &rhs) {
+  return handleUle(rhs, lhs);
+}
+
+namespace {
+/// Replaces an `arith.cmpi` with an i1 constant when the integer ranges the
+/// solver computed for both operands decide the comparison.
+struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+
+  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    // Give up unless the solver has an initialized range for each operand.
+    auto *lhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
+    if (!lhsResult || lhsResult->getValue().isUninitialized())
+      return failure();
+
+    auto *rhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
+    if (!rhsResult || rhsResult->getValue().isUninitialized())
+      return failure();
+
+    // Dispatch table indexed by predicate value; entries for predicates
+    // without a handler stay null.
+    using HandlerFunc = FailureOr<bool> (*)(const ConstantIntRanges &,
+                                            const ConstantIntRanges &);
+    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
+        handlers{};
+    using Pred = arith::CmpIPredicate;
+    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
+    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
+    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
+    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
+    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
+    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
+    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
+    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
+    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
+    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
+
+    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
+    if (!handler)
+      return failure();
+
+    const ConstantIntRanges &lhsValue = lhsResult->getValue().getValue();
+    const ConstantIntRanges &rhsValue = rhsResult->getValue().getValue();
+    FailureOr<bool> result = handler(lhsValue, rhsValue);
+
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
+        op, static_cast<int64_t>(*result), /*width*/ 1);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+};
+
+/// Runs dead code and integer range analyses on the operation, then greedily
+/// applies the range-based rewrite patterns.
+struct IntRangeOptimizationsPass
+    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeOptimizationsPatterns(patterns, solver);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::arith::populateIntRangeOptimizationsPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
+}
+
+std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
+  return std::make_unique<IntRangeOptimizationsPass>();
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi eq, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi ne, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst = arith.constant 0 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sge, %0, %cst : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst = arith.constant 0 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi slt, %0, %cst : index
+  return %1: i1
+}
+
+// -----
+
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sgt, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sle, %0, %cst1 : index
+  return %1: i1
+}