diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h @@ -0,0 +1,64 @@ +//===- Transforms.h - Arith Transforms --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +class Location; +class OpBuilder; +class OpFoldResult; +class Value; + +namespace presburger { +enum class BoundType; +} // namespace presburger + +namespace arith { + +/// Reify a bound for the given index-typed value in terms of SSA values for +/// which `stopCondition` is met. If no stop condition is specified, reify in +/// terms of the operands of the owner op. +/// +/// By default, lower/equal bounds are closed and upper bounds are open. If +/// `closedUB` is set to "true", upper bounds are also closed. +/// +/// Example: +/// %0 = arith.addi %a, %b : index +/// %1 = arith.addi %0, %c : index +/// +/// * If `stopCondition` evaluates to "true" for %0 and %c, "%0 + %c" is an EQ +/// bound for %1. +/// * If `stopCondition` evaluates to "true" for %a, %b and %c, "%a + %b + %c" +/// is an EQ bound for %1. +/// * Otherwise, if the owners of %a, %b or %c do not implement the +/// ValueBoundsOpInterface, no bound can be computed. 
+FailureOr reifyIndexValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr, + bool closedUB = false); + +/// Reify a bound for the specified dimension of the given shaped value in terms +/// of SSA values for which `stopCondition` is met. If no stop condition is +/// specified, reify in terms of the operands of the owner op. +/// +/// By default, lower/equal bounds are closed and upper bounds are open. If +/// `closedUB` is set to "true", upper bounds are also closed. +FailureOr reifyShapedValueDimBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + int64_t dim, + ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr, + bool closedUB = false); + +} // namespace arith +} // namespace mlir + +#endif // MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ EmulateWideInt.cpp ExpandOps.cpp IntRangeOptimizations.cpp + ReifyValueBounds.cpp UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS @@ -23,7 +24,9 @@ MLIRIR MLIRMemRefDialect MLIRPass + MLIRTensorDialect MLIRTransforms MLIRTransformUtils + MLIRValueBoundsOpInterface MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -0,0 +1,148 @@ +//===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/Transforms/Transforms.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; +using namespace mlir::arith; + +/// Build Arith IR for the given affine map and its operands. +/// `map` must have exactly one result. Dim expressions are read from +/// `operands` at their position; symbol expressions are read at their position +/// offset by `map.getNumDims()`. Binary expressions are built recursively. +static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, + ValueRange operands) { + assert(map.getNumResults() == 1 && "multiple results not supported yet"); + std::function buildExpr = [&](AffineExpr e) -> Value { + switch (e.getKind()) { + case AffineExprKind::Constant: + return b.create(loc, + e.cast().getValue()); + case AffineExprKind::DimId: + return operands[e.cast().getPosition()]; + case AffineExprKind::SymbolId: + return operands[e.cast().getPosition() + + map.getNumDims()]; + case AffineExprKind::Add: { + auto binaryExpr = e.cast(); + return b.create(loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); + } + case AffineExprKind::Mul: { + auto binaryExpr = e.cast(); + return b.create(loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); + } + case AffineExprKind::FloorDiv: { + auto binaryExpr = e.cast(); + return b.create(loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); + } + case AffineExprKind::CeilDiv: { + auto binaryExpr = e.cast(); + return b.create(loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); + } + case AffineExprKind::Mod: { + auto binaryExpr = e.cast(); + return b.create(loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); + } + } + // Flowing off the end of this value-returning lambda would be UB. + llvm_unreachable("unsupported AffineExpr kind"); + }; + return buildExpr(map.getResult(0)); +} + +static FailureOr +reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, + Value value, std::optional dim, +
ValueBoundsConstraintSet::StopConditionFn stopCondition, + bool closedUB) { + // Compute bound. + AffineMap boundMap; + ValueDimList mapOperands; + if (failed(ValueBoundsConstraintSet::computeBound( + boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + return failure(); + + // Materialize tensor.dim/memref.dim ops. + SmallVector operands; + for (auto valueDim : mapOperands) { + Value value = valueDim.first; + std::optional dim = valueDim.second; + + if (!dim.has_value()) { + // This is an index-typed value. + assert(value.getType().isIndex() && "expected index type"); + operands.push_back(value); + continue; + } + + assert(value.getType().cast().isDynamicDim(*dim) && + "expected dynamic dim"); + if (value.getType().isa()) { + // A tensor dimension is used: generate a tensor.dim. + operands.push_back(b.create(loc, value, *dim)); + } else if (value.getType().isa()) { + // A memref dimension is used: generate a memref.dim. + operands.push_back(b.create(loc, value, *dim)); + } else { + llvm_unreachable("cannot generate DimOp for unsupported shaped type"); + } + } + + // Check for special cases where no arith ops are needed. + if (boundMap.isSingleConstant()) { + // Bound is a constant: return an IntegerAttr. + return static_cast( + b.getIndexAttr(boundMap.getSingleConstantResult())); + } + // No arith ops are needed if the bound is a single SSA value. + if (auto expr = boundMap.getResult(0).dyn_cast()) + return static_cast(operands[expr.getPosition()]); + if (auto expr = boundMap.getResult(0).dyn_cast()) + return static_cast( + operands[expr.getPosition() + boundMap.getNumDims()]); + // General case: build Arith ops.
+ return static_cast(buildArithValue(b, loc, boundMap, operands)); +} + +FailureOr mlir::arith::reifyShapedValueDimBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, + bool closedUB) { + auto reifyToOperands = [&](Value v, std::optional d) { + // We are trying to reify a bound for `value` in terms of the owning op's + // operands. Construct a stop condition that evaluates to "true" for any SSA + // value except for `value`. I.e., the bound will be computed in terms of + // any SSA values except for `value`. The first such values are operands of + // the owner of `value`. + return v != value; + }; + return reifyValueBound(b, loc, type, value, dim, + stopCondition ? stopCondition : reifyToOperands, + closedUB); +} + +FailureOr mlir::arith::reifyIndexValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { + auto reifyToOperands = [&](Value v, std::optional d) { + return v != value; + }; + return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, + stopCondition ?
stopCondition : reifyToOperands, + closedUB); +} diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir --- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -1,10 +1,18 @@ // RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \ // RUN: -verify-diagnostics -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args use-arith-ops" \ +// RUN: -verify-diagnostics -split-input-file | FileCheck %s --check-prefix=CHECK-ARITH + // CHECK-LABEL: func @reify_through_chain( // CHECK-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index // CHECK: %[[c10:.*]] = arith.constant 10 : index // CHECK: return %[[sz0]], %[[c10]], %[[sz2]] + +// CHECK-ARITH-LABEL: func @reify_through_chain( +// CHECK-ARITH-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index +// CHECK-ARITH: %[[c10:.*]] = arith.constant 10 : index +// CHECK-ARITH: return %[[sz0]], %[[c10]], %[[sz2]] func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) { %c2 = arith.constant 2 : index %0 = tensor.empty(%sz0, %sz2) : tensor diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir @@ -1,11 +1,23 @@ // RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ // RUN: -verify-diagnostics -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-affine-reify-value-bounds="use-arith-ops" \ +// RUN: -verify-diagnostics -split-input-file | \ +// RUN: FileCheck %s --check-prefix=CHECK-ARITH + // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 5)> // CHECK-LABEL: func @arith_addi( // CHECK-SAME: %[[a:.*]]: index // CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]] // CHECK: return 
%[[apply]] + +// CHECK-ARITH-LABEL: func @arith_addi( +// CHECK-ARITH-SAME: %[[a:.*]]: index +// CHECK-ARITH: %[[c5:.*]] = arith.constant 5 : index +// CHECK-ARITH: %[[add:.*]] = arith.addi %[[c5]], %[[a]] +// CHECK-ARITH: %[[c5:.*]] = arith.constant 5 : index +// CHECK-ARITH: %[[add:.*]] = arith.addi %[[a]], %[[c5]] +// CHECK-ARITH: return %[[add]] func.func @arith_addi(%a: index) -> index { %0 = arith.constant 5 : index %1 = arith.addi %0, %a : index diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt --- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt @@ -20,6 +20,7 @@ Core LINK_LIBS PUBLIC + MLIRArithTransforms MLIRAffineTransforms MLIRAffineUtils MLIRIR diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -45,6 +46,10 @@ Option reifyToFuncArgs{ *this, "reify-to-func-args", llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)}; + + Option useArithOps{*this, "use-arith-ops", + llvm::cl::desc("Reify with arith dialect ops"), + llvm::cl::init(false)}; }; } // namespace @@ -62,7 +67,8 @@ /// Look for "test.reify_bound" ops in the input and replace their results with /// the reified values. static LogicalResult testReifyValueBounds(func::FuncOp funcOp, - bool reifyToFuncArgs) { + bool reifyToFuncArgs, + bool useArithOps) { IRRewriter rewriter(funcOp.getContext()); WalkResult result = funcOp.walk([&](Operation *op) { // Look for test.reify_bound ops. 
@@ -131,11 +137,21 @@ FailureOr(rewriter.getIndexAttr(*reifiedConst)); } else { if (dim) { - reified = reifyShapedValueDimBound(rewriter, op->getLoc(), *boundType, - value, *dim, stopCondition); + if (useArithOps) { + reified = arith::reifyShapedValueDimBound( + rewriter, op->getLoc(), *boundType, value, *dim, stopCondition); + } else { + reified = reifyShapedValueDimBound( + rewriter, op->getLoc(), *boundType, value, *dim, stopCondition); + } } else { - reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType, - value, stopCondition); + if (useArithOps) { + reified = arith::reifyIndexValueBound( + rewriter, op->getLoc(), *boundType, value, stopCondition); + } else { + reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType, + value, stopCondition); + } } } if (failed(reified)) { @@ -159,7 +175,8 @@ } void TestReifyValueBounds::runOnOperation() { - if (failed(testReifyValueBounds(getOperation(), reifyToFuncArgs))) + if (failed( + testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps))) signalPassFailure(); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10132,8 +10132,10 @@ ":MemRefDialect", ":Pass", ":Support", + ":TensorDialect", ":TransformUtils", ":Transforms", + ":ValueBoundsOpInterface", ":VectorDialect", "//llvm:Support", ], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -552,6 +552,7 @@ "//mlir:AffineTransforms", "//mlir:AffineUtils", "//mlir:Analysis", + "//mlir:ArithTransforms", "//mlir:DialectUtils", "//mlir:FuncDialect", "//mlir:IR",