diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -0,0 +1,62 @@
+//===- Transforms.h - Arith Transforms --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H
+
+#include "mlir/Support/LogicalResult.h"
+
+#include <optional>
+
+namespace mlir {
+class Location;
+class OpBuilder;
+class OpFoldResult;
+class Value;
+
+namespace presburger {
+enum class BoundType;
+} // namespace presburger
+
+namespace arith {
+
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
+/// `value` is index-typed.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+///
+/// Returns failure if no bound could be computed. A constant bound is returned
+/// as an index-typed IntegerAttr. Any other bound is returned as a Value; ops
+/// needed to compute it (arith ops and, where the bound depends on shape
+/// dimensions, tensor.dim/memref.dim ops) are created at the insertion point.
+FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
+                                        presburger::BoundType type, Value value,
+                                        std::optional<int64_t> dim,
+                                        bool closedUB = false);
+
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of SSA values for which `stopCondition` is met. `dim` must be
+/// `nullopt` if and only if `value` is index-typed.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+///
+/// Returns failure if no bound could be computed; otherwise, the result has
+/// the same form as for the overload above.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                Value value, std::optional<int64_t> dim,
+                function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+                bool closedUB = false);
+
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   EmulateWideInt.cpp
   ExpandOps.cpp
   IntRangeOptimizations.cpp
+  ReifyValueBounds.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -23,7 +24,9 @@
   MLIRIR
   MLIRMemRefDialect
   MLIRPass
+  MLIRTensorDialect
   MLIRTransforms
   MLIRTransformUtils
+  MLIRValueBoundsOpInterface
   MLIRVectorDialect
   )
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -0,0 +1,143 @@
+//===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+#include <functional>
+
+using namespace mlir;
+using namespace mlir::arith;
+
+/// Build Arith IR for the given affine map and its operands. `map` must have
+/// exactly one result; `operands` holds the map's dim operands followed by its
+/// symbol operands.
+///
+/// `floordiv` and `ceildiv` round toward negative/positive infinity and `mod`
+/// yields a value in [0, rhs), matching affine semantics for a positive
+/// divisor. The operands of a binary expression are built LHS first, so the
+/// order of the generated ops does not depend on the host compiler.
+static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
+                             ValueRange operands) {
+  assert(map.getNumResults() == 1 && "multiple results not supported yet");
+  std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
+    switch (e.getKind()) {
+    case AffineExprKind::Constant:
+      return b.create<arith::ConstantIndexOp>(
+          loc, e.cast<AffineConstantExpr>().getValue());
+    case AffineExprKind::DimId:
+      return operands[e.cast<AffineDimExpr>().getPosition()];
+    case AffineExprKind::SymbolId:
+      return operands[e.cast<AffineSymbolExpr>().getPosition() +
+                      map.getNumDims()];
+    default:
+      // Binary expression: handled below.
+      break;
+    }
+
+    // All remaining kinds are binary. Build both operands before creating the
+    // op: evaluating them as arguments of `create` would leave the order of
+    // the generated ops to the compiler.
+    auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+    Value lhs = buildExpr(binaryExpr.getLHS());
+    Value rhs = buildExpr(binaryExpr.getRHS());
+    switch (e.getKind()) {
+    case AffineExprKind::Add:
+      return b.create<arith::AddIOp>(loc, lhs, rhs);
+    case AffineExprKind::Mul:
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    case AffineExprKind::FloorDiv:
+      // Rounds toward negative infinity, unlike `arith.divsi`.
+      return b.create<arith::FloorDivSIOp>(loc, lhs, rhs);
+    case AffineExprKind::CeilDiv:
+      return b.create<arith::CeilDivSIOp>(loc, lhs, rhs);
+    case AffineExprKind::Mod: {
+      // `arith.remsi` takes the sign of the dividend, whereas affine `mod` is
+      // in [0, rhs) for a positive divisor: add `rhs` to negative remainders.
+      Value rem = b.create<arith::RemSIOp>(loc, lhs, rhs);
+      Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+      Value isNegative =
+          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, rem, zero);
+      Value corrected = b.create<arith::AddIOp>(loc, rem, rhs);
+      return b.create<arith::SelectOp>(loc, isNegative, corrected, rem);
+    }
+    default:
+      llvm_unreachable("unknown AffineExpr");
+    }
+  };
+  return buildExpr(map.getResult(0));
+}
+
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(OpBuilder &b, Location loc,
+                                                     presburger::BoundType type,
+                                                     Value value,
+                                                     std::optional<int64_t> dim,
+                                                     bool closedUB) {
+  // Reify in terms of SSA values that are different from `value`. The
+  // dimension argument of the stop condition is not needed for that.
+  auto stopCondition = [&](Value v, std::optional<int64_t> /*d*/) {
+    return v != value;
+  };
+  return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB);
+}
+
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    std::optional<int64_t> dim,
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+    bool closedUB) {
+  // Compute bound.
+  AffineMap boundMap;
+  ValueDimList mapOperands;
+  if (failed(ValueBoundsConstraintSet::computeBound(
+          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+    return failure();
+
+  // Bound is a constant: return an IntegerAttr. This is checked before the
+  // map operands are materialized so that no unused dim ops are created.
+  if (boundMap.isSingleConstant())
+    return static_cast<OpFoldResult>(
+        b.getIndexAttr(boundMap.getSingleConstantResult()));
+
+  // Materialize tensor.dim/memref.dim ops.
+  SmallVector<Value> operands;
+  operands.reserve(mapOperands.size());
+  for (const auto &[mapValue, mapDim] : mapOperands) {
+    if (!mapDim.has_value()) {
+      // This is an index-typed value.
+      assert(mapValue.getType().isIndex() && "expected index type");
+      operands.push_back(mapValue);
+      continue;
+    }
+
+    assert(mapValue.getType().cast<ShapedType>().isDynamicDim(*mapDim) &&
+           "expected dynamic dim");
+    if (mapValue.getType().isa<RankedTensorType>()) {
+      // A tensor dimension is used: generate a tensor.dim.
+      operands.push_back(b.create<tensor::DimOp>(loc, mapValue, *mapDim));
+    } else if (mapValue.getType().isa<MemRefType>()) {
+      // A memref dimension is used: generate a memref.dim.
+      operands.push_back(b.create<memref::DimOp>(loc, mapValue, *mapDim));
+    } else {
+      llvm_unreachable("cannot generate DimOp for unsupported shaped type");
+    }
+  }
+
+  // No arith ops are needed if the bound is a single SSA value.
+  if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+    return static_cast<OpFoldResult>(operands[expr.getPosition()]);
+  if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+    return static_cast<OpFoldResult>(
+        operands[expr.getPosition() + boundMap.getNumDims()]);
+  // General case: build Arith ops.
+  return static_cast<OpFoldResult>(buildArithValue(b, loc, boundMap, operands));
+}
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -1,10 +1,18 @@
 // RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
 // RUN:     -verify-diagnostics -split-input-file | FileCheck %s
 
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args use-arith-ops" \
+// RUN:     -verify-diagnostics -split-input-file | FileCheck %s --check-prefix=CHECK-ARITH
+
 // CHECK-LABEL: func @reify_through_chain(
 //  CHECK-SAME:     %[[sz0:.*]]: index, %[[sz2:.*]]: index
 //       CHECK:   %[[c10:.*]] = arith.constant 10 : index
 //       CHECK:   return %[[sz0]], %[[c10]], %[[sz2]]
+
+// CHECK-ARITH-LABEL: func @reify_through_chain(
+//  CHECK-ARITH-SAME:     %[[sz0:.*]]: index, %[[sz2:.*]]: index
+//       CHECK-ARITH:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK-ARITH:   return %[[sz0]], %[[c10]], %[[sz2]]
 func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) {
   %c2 = arith.constant 2 : index
   %0 = tensor.empty(%sz0, %sz2) : tensor<?x10x?xf32>
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -1,11 +1,23 @@
 // RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
 // RUN:     -split-input-file | FileCheck %s
 
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="use-arith-ops" \
+// RUN:     -verify-diagnostics -split-input-file | \
+// RUN:     FileCheck %s --check-prefix=CHECK-ARITH
+
 // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 5)>
 // CHECK-LABEL: func @arith_addi(
 //  CHECK-SAME:     %[[a:.*]]: index
 //       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
 //       CHECK:   return
%[[apply]]
+
+// CHECK-ARITH-LABEL: func @arith_addi(
+//  CHECK-ARITH-SAME:     %[[a:.*]]: index
+//       CHECK-ARITH:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK-ARITH:   %[[add:.*]] = arith.addi %[[c5]], %[[a]]
+//       CHECK-ARITH:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK-ARITH:   %[[add:.*]] = arith.addi %[[a]], %[[c5]]
+//       CHECK-ARITH:   return %[[add]]
 func.func @arith_addi(%a: index) -> index {
   %0 = arith.constant 5 : index
   %1 = arith.addi %0, %a : index
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -20,6 +20,7 @@
   Core
 
   LINK_LIBS PUBLIC
+  MLIRArithTransforms
   MLIRAffineTransforms
   MLIRAffineUtils
   MLIRIR
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -45,6 +46,10 @@
   Option<bool> reifyToFuncArgs{
       *this, "reify-to-func-args",
       llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
+
+  Option<bool> useArithOps{
+      *this, "use-arith-ops", llvm::cl::desc("Reify with arith dialect ops"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -62,7 +67,8 @@
 /// Look for "test.reify_bound" ops in the input and replace their results with
 /// the reified values.
 static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
-                                          bool reifyToFuncArgs) {
+                                          bool reifyToFuncArgs,
+                                          bool useArithOps) {
   IRRewriter rewriter(funcOp.getContext());
   WalkResult result = funcOp.walk([&](Operation *op) {
     // Look for test.reify_bound ops.
@@ -128,8 +134,12 @@
         reified =
             FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
       } else {
-        reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
-                                  dim, stopCondition);
+        reified = useArithOps
+                      ? arith::reifyValueBound(rewriter, op->getLoc(),
+                                               *boundType, value, dim,
+                                               stopCondition)
+                      : reifyValueBound(rewriter, op->getLoc(), *boundType,
+                                        value, dim, stopCondition);
       }
       if (failed(reified)) {
         op->emitOpError("could not reify bound");
@@ -152,7 +162,9 @@
 }
 
 void TestReifyValueBounds::runOnOperation() {
-  if (failed(testReifyValueBounds(getOperation(), reifyToFuncArgs)))
+  LogicalResult status =
+      testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps);
+  if (failed(status))
     signalPassFailure();
 }
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9792,8 +9792,10 @@
         ":MemRefDialect",
         ":Pass",
         ":Support",
+        ":TensorDialect",
         ":TransformUtils",
         ":Transforms",
+        ":ValueBoundsOpInterface",
         ":VectorDialect",
         "//llvm:Support",
     ],
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -533,6 +533,7 @@
         "//mlir:AffineTransforms",
         "//mlir:AffineUtils",
         "//mlir:Analysis",
+        "//mlir:ArithTransforms",
         "//mlir:DialectUtils",
         "//mlir:FuncDialect",
         "//mlir:IR",