diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -96,6 +96,18 @@ /// Returns true if the result of this operation is a symbol for all its /// uses in `region`. bool isValidSymbol(Region *region); + + /// Returns the leading operands that bind to the map's dimensions. + ValueRange getDimOperands() { + return OperandRange{getOperands().begin(), + getOperands().begin() + getMap().getNumDims()}; + } + + /// Returns the trailing operands that bind to the map's symbols. + ValueRange getSymbolOperands() { + return OperandRange{getOperands().begin() + getMap().getNumDims(), + getOperands().end()}; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H +#define MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace affine { +void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace affine +} // namespace mlir + +#endif // MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -17,6 +17,7 @@ #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h" #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" @@ -128,6 +129,7 @@ vector::registerTransformDialectExtension(registry); // Register all external models. 
+ affine::registerValueBoundsOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -2,6 +2,7 @@ AffineMemoryOpInterfaces.cpp AffineOps.cpp AffineValueMap.cpp + ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine @@ -18,4 +19,5 @@ MLIRMemRefDialect MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRValueBoundsOpInterface ) diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,50 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; +using presburger::BoundType; + +namespace mlir { +namespace { + +struct AffineApplyOpInterface + : public ValueBoundsOpInterface::ExternalModel<AffineApplyOpInterface, + AffineApplyOp> { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto applyOp = cast<AffineApplyOp>(op); + assert(value == applyOp.getResult() && "invalid value"); + assert(applyOp.getAffineMap().getNumResults() == 1 && + "expected single result"); + + // Align affine map result with dims/symbols in the constraint set. + AffineExpr expr = applyOp.getAffineMap().getResult(0); + SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range( + applyOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range( + applyOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + cstr.addBound(BoundType::EQ, value, bound); + } +}; + +} // namespace +} // namespace mlir + +void mlir::affine::registerValueBoundsOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) { + AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx); + }); +} diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ +// RUN: -split-input-file | FileCheck %s + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + 
s1)> +// CHECK-LABEL: func @affine_apply( +// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index +// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] +// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] +// CHECK: return %[[apply]] +func.func @affine_apply(%a: index, %b: index) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%a, %b] + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2548,6 +2548,7 @@ ":ShapedOpInterfaces", ":SideEffectInterfaces", ":Support", + ":ValueBoundsOpInterface", "//llvm:Support", ], )