diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -157,6 +157,9 @@
 
     This implementation is based on the algorithm described by Wegman and Zadeck
     in [“Constant Propagation with Conditional Branches”](https://dl.acm.org/doi/10.1145/103135.103136) (1991).
+
+    In addition, this pass uses the IntRangeAnalysis to determine which
+    scalar integer and index values are constants.
   }];
   let constructor = "mlir::createSCCPPass()";
 }
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -16,6 +16,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/IntRangeAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -164,6 +165,37 @@
 // SCCP Rewrites
 //===----------------------------------------------------------------------===//
 
+/// Replace the given value with a constant if integer range analysis has
+/// inferred that it can only take a single value. Returns success if the value
+/// was replaced, failure otherwise.
+static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
+                                         OpBuilder &b, OperationFolder &folder,
+                                         Value value) {
+  Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
+  if (!maybeInferredRange)
+    return failure();
+  const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
+  Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.hasValue())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  // Unregistered ops have no dialect to materialize the constant with.
+  if (!valueDialect)
+    return failure();
+  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
+                                              value.getType(), value.getLoc());
+  if (!constant)
+    return failure();
+
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
 /// Replace the given value with a constant if the corresponding lattice
 /// represents a constant. Returns success if the value was replaced, failure
 /// otherwise.
@@ -192,7 +224,8 @@
-/// Rewrite the given regions using the computing analysis. This replaces the
-/// uses of all values that have been computed to be constant, and erases as
-/// many newly dead operations.
-static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
+/// Rewrite the given regions using the computed analyses. This replaces the
+/// uses of all values that either analysis has proven to be constant, and
+/// erases as many newly dead operations as possible.
+static void rewrite(IntRangeAnalysis &rangeAnalysis, SCCPAnalysis &sccpAnalysis,
+                    MLIRContext *context,
                     MutableArrayRef<Region> initialRegions) {
   SmallVector<Block *> worklist;
   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
@@ -215,8 +248,10 @@
       // Replace any result with constants.
       bool replacedAll = op.getNumResults() != 0;
       for (Value res : op.getResults())
-        replacedAll &=
-            succeeded(replaceWithConstant(analysis, builder, folder, res));
+        replacedAll &= (succeeded(replaceWithConstant(rangeAnalysis, builder,
+                                                      folder, res)) ||
+                        succeeded(replaceWithConstant(sccpAnalysis, builder,
+                                                      folder, res)));
 
       // If all of the results of the operation were replaced, try to erase
       // the operation completely.
@@ -233,7 +268,8 @@
     // Replace any block arguments with constants.
     builder.setInsertionPointToStart(block);
     for (BlockArgument arg : block->getArguments())
-      (void)replaceWithConstant(analysis, builder, folder, arg);
+      if (failed(replaceWithConstant(rangeAnalysis, builder, folder, arg)))
+        (void)replaceWithConstant(sccpAnalysis, builder, folder, arg);
   }
 }
 
@@ -250,9 +286,10 @@
 
 void SCCP::runOnOperation() {
   Operation *op = getOperation();
+  IntRangeAnalysis rangeAnalysis(op);
   SCCPAnalysis analysis(op->getContext());
   analysis.run(op);
-  rewrite(analysis, op->getContext(), op->getRegions());
+  rewrite(rangeAnalysis, analysis, op->getContext(), op->getRegions());
 }
 
 std::unique_ptr<Pass> mlir::createSCCPPass() {
diff --git a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
--- a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -sccp -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @add_min_max
 // CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -sccp %s | FileCheck %s
 
 // CHECK-LABEL: func @constant
 // CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
-  TestIntRangeInference.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,115 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/IntRangeAnalysis.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/FoldUtils.h"
-
-using namespace mlir;
-
-/// Patterned after SCCP
-static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
-                                         OpBuilder &b, OperationFolder &folder,
-                                         Value value) {
-  Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
-  if (!maybeInferredRange)
-    return failure();
-  const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
-  Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
-  if (!maybeConstValue.hasValue())
-    return failure();
-
-  Operation *maybeDefiningOp = value.getDefiningOp();
-  Dialect *valueDialect =
-      maybeDefiningOp ? maybeDefiningOp->getDialect()
-                      : value.getParentRegion()->getParentOp()->getDialect();
-  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
-  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
-                                              value.getType(), value.getLoc());
-  if (!constant)
-    return failure();
-
-  value.replaceAllUsesWith(constant);
-  return success();
-}
-
-static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
-                    MutableArrayRef<Region> initialRegions) {
-  SmallVector<Block *> worklist;
-  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
-    for (Region &region : regions)
-      for (Block &block : llvm::reverse(region))
-        worklist.push_back(&block);
-  };
-
-  OpBuilder builder(context);
-  OperationFolder folder(context);
-
-  addToWorklist(initialRegions);
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    for (Operation &op : llvm::make_early_inc_range(*block)) {
-      builder.setInsertionPoint(&op);
-
-      // Replace any result with constants.
-      bool replacedAll = op.getNumResults() != 0;
-      for (Value res : op.getResults())
-        replacedAll &=
-            succeeded(replaceWithConstant(analysis, builder, folder, res));
-
-      // If all of the results of the operation were replaced, try to erase
-      // the operation completely.
-      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
-        assert(op.use_empty() && "expected all uses to be replaced");
-        op.erase();
-        continue;
-      }
-
-      // Add any the regions of this operation to the worklist.
-      addToWorklist(op.getRegions());
-    }
-
-    // Replace any block arguments with constants.
-    builder.setInsertionPointToStart(block);
-    for (BlockArgument arg : block->getArguments())
-      (void)replaceWithConstant(analysis, builder, folder, arg);
-  }
-}
-
-namespace {
-struct TestIntRangeInference
-    : PassWrapper<TestIntRangeInference, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
-
-  StringRef getArgument() const final { return "test-int-range-inference"; }
-  StringRef getDescription() const final {
-    return "Test integer range inference analysis";
-  }
-
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    IntRangeAnalysis analysis(op);
-    rewrite(analysis, op->getContext(), op->getRegions());
-  }
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestIntRangeInference() {
-  PassRegistration<TestIntRangeInference>();
-}
-} // end namespace test
-} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -79,7 +79,6 @@
 void registerTestExpandMathPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
-void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
@@ -176,7 +175,6 @@
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
-  mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();