diff --git a/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
@@ -132,6 +133,7 @@
   // Register all external models.
affine::registerValueBoundsOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerValueBoundsOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -1,3 +1,10 @@ +set(LLVM_OPTIONAL_SOURCES + ArithOps.cpp + ArithDialect.cpp + InferIntRangeInterfaceImpls.cpp + ValueBoundsOpInterfaceImpl.cpp + ) + set(LLVM_TARGET_DEFINITIONS ArithCanonicalization.td) mlir_tablegen(ArithCanonicalization.inc -gen-rewriters) add_public_tablegen_target(MLIRArithCanonicalizationIncGen) @@ -6,6 +13,7 @@ ArithOps.cpp ArithDialect.cpp InferIntRangeInterfaceImpls.cpp + ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith @@ -20,4 +28,17 @@ MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR + MLIRValueBoundsOpInterface + ) + +add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl + ValueBoundsOpInterfaceImpl.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRIR + MLIRValueBoundsOpInterface ) diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,85 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+using presburger::BoundType;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct AddIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto addIOp = cast<AddIOp>(op);
+    assert(value == addIOp.getResult() && "invalid value");
+
+    cstr.addBound(BoundType::EQ, value,
+                  cstr.getExpr(addIOp.getLhs()) +
+                      cstr.getExpr(addIOp.getRhs()));
+  }
+};
+
+struct ConstantOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
+                                                   ConstantOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto constantOp = cast<ConstantOp>(op);
+    assert(value == constantOp.getResult() && "invalid value");
+
+    if (auto attr = constantOp.getValue().dyn_cast<IntegerAttr>())
+      cstr.addBound(BoundType::EQ, value, cstr.getExpr(attr.getInt()));
+  }
+};
+
+struct SubIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto subIOp = cast<SubIOp>(op);
+    assert(value == subIOp.getResult() && "invalid value");
+
+    cstr.addBound(BoundType::EQ, value,
+                  cstr.getExpr(subIOp.getLhs()) -
+                      cstr.getExpr(subIOp.getRhs()));
+  }
+};
+
+struct MulIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto mulIOp = cast<MulIOp>(op);
+    assert(value == mulIOp.getResult() && "invalid value");
+
+    cstr.addBound(BoundType::EQ, value,
+                  cstr.getExpr(mulIOp.getLhs()) *
+                      cstr.getExpr(mulIOp.getRhs()));
+  }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
+    arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
+    arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
+    arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN:     -split-input-file | FileCheck %s
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 5)>
+// CHECK-LABEL: func @arith_addi(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_addi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.addi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (-s0 + 5)>
+// CHECK-LABEL: func @arith_subi(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_subi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.subi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 5)>
+// CHECK-LABEL: func @arith_muli(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_muli(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.muli %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_const()
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: return %[[c5]]
+func.func @arith_const() -> index {
+  %c5 = arith.constant 5 : index
+  %0 = "test.reify_bound"(%c5) : (index) -> (index)
+  return %0 : index
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -31,6 +31,10 @@
   TestReifyValueBounds() = default;
   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, arith::ArithDialect>();
+  }
+
   void runOnOperation() override;
 
 private:
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6989,6 +6989,7 @@
         ":ArithToLLVM",
         ":ArithToSPIRV",
         ":ArithTransforms",
+        ":ArithValueBoundsOpInterfaceImpl",
         ":ArmNeonDialect",
         ":ArmSVEDialect",
         ":ArmSVETransforms",
@@ -8641,6 +8642,18 @@
     ],
 )
 
+cc_library(
+    name = "ArithValueBoundsOpInterfaceImpl",
+    srcs = ["lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp"],
+    hdrs = ["include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":IR",
+        ":ValueBoundsOpInterface",
+    ],
+)
+
 cc_library(
     name = "TilingInterface",
     srcs = ["lib/Interfaces/TilingInterface.cpp"],
@@ -9707,12 +9720,11 @@
 
 cc_library(
     name = "ArithDialect",
-    srcs = glob(
-        [
-            "lib/Dialect/Arith/IR/*.cpp",
-            "lib/Dialect/Arith/IR/*.h",
-        ],
-    ),
+    srcs = [
+        "lib/Dialect/Arith/IR/ArithDialect.cpp",
+        "lib/Dialect/Arith/IR/ArithOps.cpp",
+        "lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp",
+    ],
     hdrs = [
         "include/mlir/Dialect/Arith/IR/Arith.h",
"include/mlir/Transforms/InliningUtils.h",