diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -14,12 +14,21 @@
 #ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LogicalResult.h"
 
 namespace mlir {
+class AffineApplyOp;
+class Location;
+class OpBuilder;
+class OpFoldResult;
 class RewritePatternSet;
 class RewriterBase;
-class AffineApplyOp;
+class Value;
+
+namespace presburger {
+enum class BoundType;
+} // namespace presburger
 
 /// Populate patterns that expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
@@ -40,6 +49,32 @@
 /// maximally compose chains of AffineApplyOps.
 FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
+/// `value` is index-typed.
+FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
+                                        presburger::BoundType type, Value value,
+                                        std::optional<int64_t> dim);
+
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of SSA values for which `stopCondition` is met. `dim` must be
+/// `nullopt` if and only if `value` is index-typed.
+///
+/// Example:
+/// %0 = arith.addi %a, %b : index
+/// %1 = arith.addi %0, %c : index
+///
+/// * If `stopCondition` evaluates to "true" for %0 and %c, "%0 + %c" is an EQ
+///   bound for %1.
+/// * If `stopCondition` evaluates to "true" for %a, %b and %c, "%a + %b + %c"
+///   is an EQ bound for %1.
+/// * Otherwise, if the owners of %a, %b or %c do not implement the
+///   ValueBoundsOpInterface, no bound can be computed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                Value value, std::optional<int64_t> dim,
+                ValueBoundsConstraintSet::StopConditionFn stopCondition);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -65,6 +65,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -142,6 +143,7 @@
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
+  tensor::registerValueBoundsOpInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
 }
 
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -12,6 +12,7 @@
 add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
+add_mlir_interface(ValueBoundsOpInterface)
 add_mlir_interface(VectorInterfaces)
 add_mlir_interface(ViewLikeInterface)
 
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -0,0 +1,198 @@
+//===- ValueBoundsOpInterface.h - Value Bounds ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
+#define MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
+
+#include "mlir/Analysis/FlatLinearValueConstraints.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
+
+/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
+/// constraint system and mapping of constrained variables to index-typed
+/// values or dimension sizes of shaped values.
+///
+/// Interface implementations of `ValueBoundsOpInterface` use `bound` to
+/// insert constraints about their results and/or region block arguments into
+/// the constraint set in the form of an AffineExpr. When a bound should be
+/// expressed in terms of another value/dimension, `getExpr` can be used to
+/// retrieve an AffineExpr that represents the specified value/dimension.
+///
+/// When a value/dimension is retrieved for the first time through `getExpr`,
+/// it is added to an internal worklist. See `computeBound` for more details.
+///
+/// Note: Any modification of existing IR invalidates the data stored in this
+/// class. Adding new operations is allowed.
+class ValueBoundsConstraintSet {
+protected:
+  /// Helper class that builds a bound for a shaped value dimension or
+  /// index-typed value.
+  class BoundBuilder {
+  public:
+    /// Specify a dimension, assuming that the underlying value is a shaped
+    /// value.
+    BoundBuilder &operator[](int64_t dim);
+
+    // These overloaded operators add lower/upper/equality bounds.
+    void operator<(AffineExpr expr);
+    void operator<=(AffineExpr expr);
+    void operator>(AffineExpr expr);
+    void operator>=(AffineExpr expr);
+    void operator==(AffineExpr expr);
+    void operator<(OpFoldResult ofr);
+    void operator<=(OpFoldResult ofr);
+    void operator>(OpFoldResult ofr);
+    void operator>=(OpFoldResult ofr);
+    void operator==(OpFoldResult ofr);
+    void operator<(int64_t i);
+    void operator<=(int64_t i);
+    void operator>(int64_t i);
+    void operator>=(int64_t i);
+    void operator==(int64_t i);
+
+  protected:
+    friend class ValueBoundsConstraintSet;
+    BoundBuilder(ValueBoundsConstraintSet &cstr, Value value)
+        : cstr(cstr), value(value) {}
+
+  private:
+    BoundBuilder(const BoundBuilder &) = delete;
+    BoundBuilder &operator=(const BoundBuilder &) = delete;
+    bool operator==(const BoundBuilder &) = delete;
+    bool operator!=(const BoundBuilder &) = delete;
+
+    ValueBoundsConstraintSet &cstr;
+    Value value;
+    std::optional<int64_t> dim = std::nullopt;
+  };
+
+public:
+  /// The stop condition when traversing the backward slice of a shaped value/
+  /// index-typed value. The traversal continues until the stop condition
+  /// evaluates to "true" for a value.
+  using StopConditionFn = function_ref<bool(Value)>;
+
+  /// Compute a bound for the given index-typed value or shape dimension size.
+  /// The computed bound is stored in `resultMap`. The operands of the bound are
+  /// stored in `mapOperands`. An operand is either an index-typed SSA value
+  /// or a shaped value and a dimension.
+  ///
+  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
+  /// is computed in terms of values for which `stopCondition` evaluates to
+  /// "true". To that end, the backward slice (reverse use-def chain) of the
+  /// given value is visited in a worklist-driven manner and the constraint set
+  /// is populated according to `ValueBoundsOpInterface` for each visited value.
+  static LogicalResult computeBound(AffineMap &resultMap,
+                                    ValueDimList &mapOperands,
+                                    presburger::BoundType type, Value value,
+                                    std::optional<int64_t> dim,
+                                    StopConditionFn stopCondition);
+
+  /// Add a bound for the given index-typed value or shaped value. This function
+  /// returns a builder that adds the bound.
+  BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
+
+  /// Return an expression that represents the given index-typed value or shaped
+  /// value dimension. If this value/dimension was not used so far, it is added
+  /// to the worklist.
+  ///
+  /// `dim` must be `nullopt` if and only if the given value is of index type.
+  AffineExpr getExpr(Value value, std::optional<int64_t> dim = std::nullopt);
+
+  /// Return an expression that represents a constant or index-typed SSA value.
+  /// In case of a value, if this value was not used so far, it is added to the
+  /// worklist.
+  AffineExpr getExpr(OpFoldResult ofr);
+
+  /// Return an expression that represents a constant.
+  AffineExpr getExpr(int64_t constant);
+
+protected:
+  /// Dimension identifier to indicate a value is index-typed. This is used for
+  /// internal data structures/API only.
+  static constexpr int64_t kIndexValue = -1;
+
+  /// An index-typed value or the dimension of a shaped-type value.
+  using ValueDim = std::pair<Value, int64_t>;
+
+  ValueBoundsConstraintSet(Value value, std::optional<int64_t> dim);
+
+  /// Iteratively process all elements on the worklist until it is empty. An
+  /// index-typed value or shaped value dimension that meets `stopCondition` is
+  /// not processed any further.
+  void processWorklist(StopConditionFn stopCondition);
+
+  /// Bound the given column in the underlying constraint set by the given
+  /// expression.
+  void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr);
+
+  /// Return the column position of the given value/dimension. Asserts that the
+  /// value/dimension exists in the constraint set.
+  int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
+  /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
+  /// "false", a dimension is added.
+  ///
+  /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
+  /// cannot be multiplied. Furthermore, bounds can only be queried for
+  /// dimensions but not for symbols.
+  int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
+
+  /// Project out the given column in the constraint set.
+  void projectOut(int64_t pos);
+
+  /// Project out all columns for which the condition holds.
+  void projectOut(function_ref<bool(ValueDim)> condition);
+
+  /// Mapping of columns to values/shape dimensions.
+  SmallVector<ValueDim> positionToValueDim;
+  /// Reverse mapping of values/shape dimensions to columns.
+  DenseMap<ValueDim, int64_t> valueDimToPosition;
+
+  /// Worklist of values/shape dimensions that have not been processed yet.
+  SetVector<int64_t> worklist;
+
+  /// Constraint system of equalities and inequalities.
+  FlatLinearConstraints cstr;
+
+  /// Builder for constructing affine expressions.
+  Builder builder;
+};
+
+} // namespace mlir
+
+#include "mlir/Interfaces/ValueBoundsOpInterface.h.inc"
+
+namespace mlir {
+
+/// Default implementation for destination style ops: Tied OpResults and
+/// OpOperands have the same type.
+template <typename ConcreteOp>
+struct DstValueBoundsOpInterfaceExternalModel
+    : public ValueBoundsOpInterface::ExternalModel<
+          DstValueBoundsOpInterfaceExternalModel<ConcreteOp>, ConcreteOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto dstOp = cast<DestinationStyleOpInterface>(op);
+    assert(value.getDefiningOp() == dstOp && "invalid value");
+
+    Value tiedOperand = dstOp.getTiedOpOperand(value.cast<OpResult>())->get();
+    cstr.bound(value)[dim] == cstr.getExpr(tiedOperand, dim);
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td
@@ -0,0 +1,63 @@
+//===-- ValueBoundsOpInterface.td - Value Bounds -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VALUEBOUNDSOPINTERFACE
+#define VALUEBOUNDSOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ValueBoundsOpInterface : OpInterface<"ValueBoundsOpInterface"> {
+  let description = [{
+    This interface allows operations with index-typed and/or shaped value-typed
+    results/block arguments to specify range bounds. These bounds are stored in
+    a constraint set. The constraint set can then be queried to compute bounds
+    in terms of other values that are stored in the constraint set.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the constraint set with bounds for the given index-typed
+        value.
+
+        Note: If `value` is a block argument, it must belong to an entry block
+        of a region. Unstructured control flow graphs are not supported at the
+        moment.
+      }],
+      /*retType=*/"void",
+      /*methodName=*/"populateBoundsForIndexValue",
+      /*args=*/(ins "::mlir::Value":$value,
+                    "::mlir::ValueBoundsConstraintSet &":$cstr),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        llvm_unreachable("not implemented");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the constraint set with bounds for the size of the specified
+        dimension of the given shaped value.
+
+        Note: If `value` is a block argument, it must belong to an entry block
+        of a region. Unstructured control flow graphs are not supported at the
+        moment.
+      }],
+      /*retType=*/"void",
+      /*methodName=*/"populateBoundsForShapedValueDim",
+      /*args=*/(ins "::mlir::Value":$value,
+                    "int64_t":$dim,
+                    "::mlir::ValueBoundsConstraintSet &":$cstr),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        llvm_unreachable("not implemented");
+      }]
+    >,
+  ];
+}
+
+#endif // VALUEBOUNDSOPINTERFACE
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@
   LoopUnroll.cpp
   LoopUnrollAndJam.cpp
   PipelineDataTransfer.cpp
+  ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
 
@@ -33,7 +34,9 @@
   MLIRPass
   MLIRSCFUtils
   MLIRSideEffectInterfaces
+  MLIRTensorDialect
   MLIRTransformUtils
+  MLIRValueBoundsOpInterface
   MLIRVectorDialect
   MLIRVectorUtils
   MLIRVectorToLLVM
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -0,0 +1,87 @@
+//===- ReifyValueBounds.cpp --- Reify value bounds with affine ops ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+FailureOr<OpFoldResult> mlir::reifyValueBound(OpBuilder &b, Location loc,
+                                              presburger::BoundType type,
+                                              Value value,
+                                              std::optional<int64_t> dim) {
+  // We are trying to reify a bound for `value`. Construct a stop condition that
+  // evaluates to "true" for any SSA value except for `value`. I.e., the bound
+  // will be computed in terms of any SSA values except for `value`. The first
+  // such values are operands of the owner of `value`.
+  auto stopCondition = [&](Value v) {
+    // Reify in terms of SSA values that are different from `value`.
+    return v != value;
+  };
+  return reifyValueBound(b, loc, type, value, dim, stopCondition);
+}
+
+FailureOr<OpFoldResult>
+mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                      Value value, std::optional<int64_t> dim,
+                      ValueBoundsConstraintSet::StopConditionFn stopCondition) {
+  // Compute bound.
+  AffineMap boundMap;
+  ValueDimList mapOperands;
+  if (failed(ValueBoundsConstraintSet::computeBound(boundMap, mapOperands, type,
+                                                    value, dim, stopCondition)))
+    return failure();
+
+  // Materialize tensor.dim/memref.dim ops.
+  SmallVector<Value> operands;
+  for (const auto &valueDim : mapOperands) {
+    Value operand = valueDim.first;
+    std::optional<int64_t> operandDim = valueDim.second;
+
+    if (!operandDim.has_value()) {
+      // This is an index-typed value.
+      assert(operand.getType().isIndex() && "expected index type");
+      operands.push_back(operand);
+      continue;
+    }
+    assert((!cast<ShapedType>(operand.getType()).hasRank() ||
+            cast<ShapedType>(operand.getType()).isDynamicDim(*operandDim)) &&
+           "expected dynamic dim");
+    if (isa<TensorType>(operand.getType())) {
+      // A tensor dimension (ranked or unranked) is used: generate a tensor.dim.
+      operands.push_back(b.create<tensor::DimOp>(loc, operand, *operandDim));
+    } else if (isa<BaseMemRefType>(operand.getType())) {
+      // A memref dimension (ranked or unranked) is used: generate a memref.dim.
+      operands.push_back(b.create<memref::DimOp>(loc, operand, *operandDim));
+    } else {
+      llvm_unreachable("cannot generate DimOp for unsupported shaped type");
+    }
+  }
+
+  // Simplify and return bound.
+  mlir::canonicalizeMapAndOperands(&boundMap, &operands);
+  // Check for special cases where no affine.apply op is needed.
+  if (boundMap.isSingleConstant()) {
+    // Bound is a constant: return an IntegerAttr.
+    return static_cast<OpFoldResult>(
+        b.getIndexAttr(boundMap.getSingleConstantResult()));
+  }
+  // No affine.apply op is needed if the bound is a single SSA value.
+  if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+    return static_cast<OpFoldResult>(operands[expr.getPosition()]);
+  if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+    return static_cast<OpFoldResult>(
+        operands[expr.getPosition() + boundMap.getNumDims()]);
+  // General case: build affine.apply op.
+  return static_cast<OpFoldResult>(
+      b.create<AffineApplyOp>(loc, boundMap, operands).getResult());
+}
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -3,11 +3,13 @@
   TensorInferTypeOpInterfaceImpl.cpp
   TensorOps.cpp
   TensorTilingInterfaceImpl.cpp
+  ValueBoundsOpInterfaceImpl.cpp
   )
 
 add_mlir_dialect_library(MLIRTensorDialect
   TensorDialect.cpp
   TensorOps.cpp
+  ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Tensor
@@ -32,6 +34,7 @@
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRSupport
+  MLIRValueBoundsOpInterface
   MLIRViewLikeInterface
   )
 
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,129 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl.
of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace tensor {
+namespace {
+
+struct CastOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto castOp = cast<CastOp>(op);
+    assert(value == castOp.getResult() && "invalid value");
+
+    if (castOp.getResult().getType().isa<RankedTensorType>() &&
+        castOp.getSource().getType().isa<RankedTensorType>()) {
+      cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
+    }
+  }
+};
+
+struct DimOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto dimOp = cast<DimOp>(op);
+    assert(value == dimOp.getResult() && "invalid value");
+    // Only constant, non-negative dimension indices can be bounded.
+    auto constIndex = dimOp.getConstantIndex();
+    if (!constIndex.has_value() || *constIndex < 0)
+      return;
+    cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+  }
+};
+
+struct EmptyOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto emptyOp = cast<EmptyOp>(op);
+    assert(value == emptyOp.getResult() && "invalid value");
+
+    cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
+  }
+};
+
+struct ExtractSliceOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
+                                                   ExtractSliceOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto extractSliceOp = cast<ExtractSliceOp>(op);
+    assert(value == extractSliceOp.getResult() && "invalid value");
+
+    llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
+    SmallVector<OpFoldResult> sizes = extractSliceOp.getMixedSizes();
+    for (int64_t i = 0, e = sizes.size(), ctr = -1; i < e; ++i) {
+      // Skip over rank-reduced dimensions.
+      if (!dropped.test(i))
+        ++ctr;
+      if (ctr == dim) {
+        cstr.bound(value)[dim] == sizes[i];
+        return;
+      }
+    }
+    llvm_unreachable("could not find non-rank-reduced dim");
+  }
+};
+
+struct PadOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto padOp = cast<PadOp>(op);
+    assert(value == padOp.getResult() && "invalid value");
+
+    AffineExpr expr = cstr.getExpr(padOp.getSource(), dim) +
+                      cstr.getExpr(padOp.getMixedLowPad()[dim]) +
+                      cstr.getExpr(padOp.getMixedHighPad()[dim]);
+    cstr.bound(value)[dim] == expr;
+  }
+};
+
+struct RankOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto rankOp = cast<RankOp>(op);
+    assert(value == rankOp.getResult() && "invalid value");
+
+    auto tensorType = rankOp.getTensor().getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      return;
+    cstr.bound(value) == tensorType.getRank();
+  }
+};
+
+} // namespace
+} // namespace tensor
+} // namespace mlir
+
+void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+    tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+    tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
+    tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+    tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
+        *ctx);
+    tensor::InsertOp::attachInterface<
+        DstValueBoundsOpInterfaceExternalModel<tensor::InsertOp>>(*ctx);
+    tensor::InsertSliceOp::attachInterface<
+        DstValueBoundsOpInterfaceExternalModel<tensor::InsertSliceOp>>(*ctx);
+    tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
+    tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -14,6 +14,7 @@
   ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
+  ValueBoundsOpInterface.cpp
   VectorInterfaces.cpp
   ViewLikeInterface.cpp
   )
@@ -52,4 +53,20 @@
 add_mlir_interface_library(VectorInterfaces)
 add_mlir_interface_library(ViewLikeInterface)
 
+add_mlir_library(MLIRValueBoundsOpInterface
+  ValueBoundsOpInterface.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+  DEPENDS
+  MLIRDestinationStyleOpInterfaceIncGen
+  MLIRValueBoundsOpInterfaceIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAnalysis
+  MLIRDestinationStyleOpInterface
+  MLIRIR
+  )
+
 add_subdirectory(Utils)
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -0,0 +1,412 @@
+//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/APSInt.h"
+
+using namespace mlir;
+using presburger::BoundType;
+using presburger::VarKind;
+
+namespace mlir {
+#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
+} // namespace mlir
+
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+  // Case 1: Check for Constant integer.
+  if (auto val = ofr.dyn_cast<Value>()) {
+    APSInt intVal;
+    if (matchPattern(val, m_ConstantInt(&intVal)))
+      return intVal.getSExtValue();
+    return std::nullopt;
+  }
+  // Case 2: Check for IntegerAttr.
+  Attribute attr = ofr.dyn_cast<Attribute>();
+  if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+    return intAttr.getValue().getSExtValue();
+  return std::nullopt;
+}
+
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(Value value,
+                                                   std::optional<int64_t> dim)
+    : builder(value.getContext()) {
+  insert(value, dim, /*isSymbol=*/false);
+}
+
+#ifndef NDEBUG
+static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
+  if (value.getType().isIndex()) {
+    assert(!dim.has_value() && "invalid dim value");
+  } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
+    assert(dim.has_value() && *dim >= 0 && "invalid dim value");
+    if (shapedType.hasRank())
+      assert(*dim < shapedType.getRank() && "invalid dim value");
+  } else {
+    llvm_unreachable("unsupported type");
+  }
+}
+#endif // NDEBUG
+
+void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
+                                        AffineExpr expr) {
+  LogicalResult status = cstr.addBound(
+      type, pos,
+      AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr));
+  (void)status;
+  assert(succeeded(status) && "failed to add bound to constraint system");
+}
+
+AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
+                                             std::optional<int64_t> dim) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  auto shapedType = dyn_cast<ShapedType>(value.getType());
+  if (shapedType) {
+    // Static dimension: return constant directly.
+    if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
+      return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
+  } else {
+    // Constant index value: return directly.
+    if (auto constInt = getConstantIntValue(value))
+      return builder.getAffineConstantExpr(*constInt);
+  }
+
+  // Dynamic value: add to constraint set.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  if (!valueDimToPosition.contains(valueDim))
+    (void)insert(value, dim);
+  int64_t pos = getPos(value, dim);
+  return pos < cstr.getNumDimVars()
+             ? builder.getAffineDimExpr(pos)
+             : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+}
+
+AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
+  if (Value value = ofr.dyn_cast<Value>())
+    return getExpr(value, /*dim=*/std::nullopt);
+  auto constInt = getConstantIntValue(ofr);
+  assert(constInt.has_value() && "expected Integer constant");
+  return builder.getAffineConstantExpr(*constInt);
+}
+
+AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
+  return builder.getAffineConstantExpr(constant);
+}
+
+int64_t ValueBoundsConstraintSet::insert(Value value,
+                                         std::optional<int64_t> dim,
+                                         bool isSymbol) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  assert(!valueDimToPosition.contains(valueDim) && "already mapped");
+  int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
+                         : cstr.appendVar(VarKind::SetDim);
+  positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
+  // Update reverse mapping.
+  for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
+    valueDimToPosition[positionToValueDim[i]] = i;
+
+  worklist.insert(pos);
+  return pos;
+}
+
+int64_t ValueBoundsConstraintSet::getPos(Value value,
+                                         std::optional<int64_t> dim) const {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+  assert((value.isa<OpResult>() ||
+          value.cast<BlockArgument>().getOwner()->isEntryBlock()) &&
+         "unstructured control flow is not supported");
+#endif // NDEBUG
+
+  auto it =
+      valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
+  assert(it != valueDimToPosition.end() && "expected mapped entry");
+  return it->second;
+}
+
+static Operation *getOwnerOfValue(Value value) {
+  if (auto bbArg = value.dyn_cast<BlockArgument>())
+    return bbArg.getOwner()->getParentOp();
+  return value.getDefiningOp();
+}
+
+void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+  while (!worklist.empty()) {
+    int64_t pos = worklist.pop_back_val();
+    ValueDim valueDim = positionToValueDim[pos];
+    Value value = valueDim.first;
+    int64_t dim = valueDim.second;
+
+    // Check for static dim size.
+    if (dim != kIndexValue) {
+      auto shapedType = cast<ShapedType>(value.getType());
+      if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
+        bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
+        continue;
+      }
+    }
+
+    // Do not process any further if the stop condition is met.
+    if (stopCondition(value))
+      continue;
+
+    // Query `ValueBoundsOpInterface` for constraints. New items may be added to
+    // the worklist.
+    auto valueBoundsOp =
+        dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    if (!valueBoundsOp)
+      continue;
+    if (dim == kIndexValue) {
+      valueBoundsOp.populateBoundsForIndexValue(value, *this);
+    } else {
+      valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
+    }
+  }
+}
+
+void ValueBoundsConstraintSet::projectOut(int64_t pos) {
+  assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
+         "invalid position");
+  cstr.projectOut(pos);
+  bool erased = valueDimToPosition.erase(positionToValueDim[pos]);
+  (void)erased;
+  assert(erased && "inconsistent reverse mapping");
+  positionToValueDim.erase(positionToValueDim.begin() + pos);
+  // Update reverse mapping.
+  for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
+    valueDimToPosition[positionToValueDim[i]] = i;
+}
+
+void ValueBoundsConstraintSet::projectOut(
+    function_ref<bool(ValueDim)> condition) {
+  int64_t nextPos = 0;
+  while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+    if (condition(positionToValueDim[nextPos])) {
+      projectOut(nextPos);
+      // The column was projected out so another column is now at that position.
+      // Do not increase the counter.
+    } else {
+      ++nextPos;
+    }
+  }
+}
+
+LogicalResult ValueBoundsConstraintSet::computeBound(
+    AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
+    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  // Only EQ bounds are supported at the moment.
+  assert(type == BoundType::EQ && "unsupported bound type");
+
+  Builder b(value.getContext());
+  mapOperands.clear();
+
+  // Special case: The size of a static dimension is a constant. Return it
+  // directly, so that `mapOperands` never contains dimensions of static size.
+  if (dim.has_value()) {
+    auto shapedType = cast<ShapedType>(value.getType());
+    if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) {
+      resultMap = AffineMap::get(
+          /*dimCount=*/0, /*symbolCount=*/0,
+          b.getAffineConstantExpr(shapedType.getDimSize(*dim)));
+      return success();
+    }
+  }
+
+  if (stopCondition(value)) {
+    // Special case: If the stop condition is satisfied for the input
+    // value/dimension, directly return it.
+    mapOperands.push_back(std::make_pair(value, dim));
+    resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                               b.getAffineDimExpr(0));
+    return success();
+  }
+
+  // Process the backward slice of `value` (i.e., reverse use-def chain) until
+  // `stopCondition` is met.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  ValueBoundsConstraintSet cstr(value, dim);
+  cstr.processWorklist(stopCondition);
+
+  // Project out all variables (apart from `valueDim`) that do not match the
+  // stop condition.
+  cstr.projectOut([&](ValueDim p) {
+    // Do not project out `valueDim`.
+    if (valueDim == p)
+      return false;
+    return !stopCondition(p.first);
+  });
+
+  // Compute lower and upper bounds for `valueDim`.
+  int64_t pos = cstr.getPos(value, dim);
+  SmallVector<AffineMap> lb(1), ub(1);
+  cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+                           /*getClosedUB=*/true);
+  // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
+  // case, no lower/upper bound can be computed at the moment.
+  if (lb.empty() || !lb[0] || ub.empty() || !ub[0] ||
+      lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1)
+    return failure();
+
+  // Look for same lower and upper bound: EQ bound.
+  if (ub[0] != lb[0])
+    return failure();
+
+  // Gather all SSA values that are used in the computed bound.
+  assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
+         "inconsistent mapping state");
+  SmallVector<AffineExpr> replacementDims, replacementSymbols;
+  int64_t numDims = 0, numSymbols = 0;
+  for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
+    // Skip `value`.
+    if (i == pos)
+      continue;
+    // Check if the position `i` is used in the generated bound. If so, it must
+    // be included in the generated affine.apply op.
+    bool used = false;
+    bool isDim = i < cstr.cstr.getNumDimVars();
+    if (isDim) {
+      if (lb[0].isFunctionOfDim(i))
+        used = true;
+    } else {
+      if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
+        used = true;
+    }
+
+    if (!used) {
+      // Not used: Remove dim/symbol from the result.
+      if (isDim) {
+        replacementDims.push_back(b.getAffineConstantExpr(0));
+      } else {
+        replacementSymbols.push_back(b.getAffineConstantExpr(0));
+      }
+      continue;
+    }
+
+    if (isDim) {
+      replacementDims.push_back(b.getAffineDimExpr(numDims++));
+    } else {
+      replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
+    }
+
+    ValueBoundsConstraintSet::ValueDim operand = cstr.positionToValueDim[i];
+    Value operandValue = operand.first;
+    int64_t operandDim = operand.second;
+    if (operandDim == ValueBoundsConstraintSet::kIndexValue) {
+      // An index-typed value is used: can be used directly in the affine.apply
+      // op.
+      assert(operandValue.getType().isIndex() && "expected index type");
+      mapOperands.push_back(std::make_pair(operandValue, std::nullopt));
+      continue;
+    }
+
+#ifndef NDEBUG
+    auto shapedType = cast<ShapedType>(operandValue.getType());
+    assert((!shapedType.hasRank() || shapedType.isDynamicDim(operandDim)) &&
+           "expected dynamic dim");
+#endif // NDEBUG
+    mapOperands.push_back(std::make_pair(operandValue, operandDim));
+  }
+
+  resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols,
+                                          numDims, numSymbols);
+  return success();
+}
+
+ValueBoundsConstraintSet::BoundBuilder &
+ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
+  assert(!this->dim.has_value() && "dim was already set");
+  this->dim = dim;
+#ifndef NDEBUG
+  assertValidValueDim(value, this->dim);
+#endif // NDEBUG
+  return *this;
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) {
+#ifndef NDEBUG
+  assertValidValueDim(value, this->dim);
+#endif // NDEBUG
+  cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) {
+  // Integer semantics: `value <= expr` holds iff `value < expr + 1` holds.
+  operator<(expr + 1);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) {
+
operator>=(expr + 1); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) { +#ifndef NDEBUG + assertValidValueDim(value, this->dim); +#endif // NDEBUG + cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) { +#ifndef NDEBUG + assertValidValueDim(value, this->dim); +#endif // NDEBUG + cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) { + operator<(cstr.getExpr(ofr)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) { + operator<=(cstr.getExpr(ofr)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) { + operator>(cstr.getExpr(ofr)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) { + operator>=(cstr.getExpr(ofr)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) { + operator==(cstr.getExpr(ofr)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) { + operator<(cstr.getExpr(i)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) { + operator<=(cstr.getExpr(i)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) { + operator>(cstr.getExpr(i)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) { + operator>=(cstr.getExpr(i)); +} + +void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) { + operator==(cstr.getExpr(i)); +} diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \ +// RUN: -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: func @reify_through_chain( +// 
CHECK-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[sz0]], %[[c10]], %[[sz2]] +func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %0 = tensor.empty(%sz0, %sz2) : tensor + %1 = tensor.cast %0 : tensor to tensor + %pos = arith.constant 0 : index + %f = arith.constant 0.0 : f32 + %2 = tensor.insert %f into %1[%pos, %pos, %pos] : tensor + %3 = tensor.dim %2, %c2 : tensor + + %4 = "test.reify_bound"(%2) {dim = 0} : (tensor) -> (index) + %5 = "test.reify_bound"(%2) {dim = 1} : (tensor) -> (index) + %6 = "test.reify_bound"(%3) : (index) -> (index) + + return %4, %5, %6 : index, index, index +} diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ +// RUN: -split-input-file | FileCheck %s + +func.func @unknown_op() -> index { + %0 = "test.foo"() : () -> (tensor) + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cast( +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[c10]] +func.func @cast(%t: tensor<10xf32>) -> index { + %0 = tensor.cast %t : tensor<10xf32> to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +func.func @cast_unranked(%t: tensor<*xf32>) -> index { + %0 = tensor.cast %t : tensor<*xf32> to tensor + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @dim( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// 
CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: return %[[dim]] +func.func @dim(%t: tensor) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %t, %c0 : tensor + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @empty( +// CHECK-SAME: %[[sz:.*]]: index +// CHECK: %[[c6:.*]] = arith.constant 6 : index +// CHECK: return %[[c6]], %[[sz]] +func.func @empty(%sz: index) -> (index, index) { + %0 = tensor.empty(%sz) : tensor<6x?xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<6x?xf32>) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<6x?xf32>) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_dynamic( +// CHECK-SAME: %[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_dynamic(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[2][%sz][1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_static( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: return %[[c5]] +func.func @extract_slice_static(%t: tensor) -> index { + %0 = tensor.extract_slice %t[2][5][1] : tensor to tensor<5xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_rank_reduce( +// CHECK-SAME: %[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_rank_reduce(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[0, 2][1, %sz][1, 1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @insert( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: return %[[dim]] 
+func.func @insert(%t: tensor, %f: f32, %pos: index) -> index { + %0 = tensor.insert %f into %t[%pos] : tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 12)> +// CHECK-LABEL: func @pad( +// CHECK-SAME: %[[t:.*]]: tensor, %[[a:.*]]: index, %[[b:.*]]: index +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: %[[bound0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[a]]] +// CHECK: %[[bound1:.*]] = affine.apply #[[$map1]]()[%[[b]]] +// CHECK: return %[[bound0]], %[[bound1]] +func.func @pad(%t: tensor, %a: index, %b: index) -> (index, index) { + %pad = arith.constant 0.0 : f32 + %0 = tensor.pad %t low[%a, 5] high[%a, %b] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %pad : f32 + } : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @rank( +// CHECK-SAME: %[[t:.*]]: tensor<5xf32> +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: return %[[c1]] +func.func @rank(%t: tensor<5xf32>) -> index { + %0 = tensor.rank %t : tensor<5xf32> + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt --- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt @@ -4,6 +4,7 @@ TestAffineLoopUnswitching.cpp TestAffineLoopParametricTiling.cpp TestDecomposeAffineOps.cpp + TestReifyValueBounds.cpp TestLoopFusion.cpp TestLoopMapping.cpp TestLoopPermutation.cpp diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -0,0 +1,120 @@ +//===- TestReifyValueBounds.cpp - Test value bounds reification -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Pass/Pass.h" + +#define PASS_NAME "test-affine-reify-value-bounds" + +using namespace mlir; + +namespace { + +/// This pass applies the permutation on the first maximal perfect nest. +struct TestReifyValueBounds + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds) + + StringRef getArgument() const final { return PASS_NAME; } + StringRef getDescription() const final { + return "Tests ValueBoundsOpInterface with affine dialect reification"; + } + TestReifyValueBounds() = default; + TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){}; + + void runOnOperation() override; + +private: + Option reifyToFuncArgs{ + *this, "reify-to-func-args", + llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)}; +}; + +} // namespace + +/// Look for "test.reify_bound" ops in the input and replace their results with +/// the reified values. +static LogicalResult testReifyValueBounds(func::FuncOp funcOp, + bool reifyToFuncArgs) { + IRRewriter rewriter(funcOp.getContext()); + WalkResult result = funcOp.walk([&](Operation *op) { + // Look for test.reify_bound ops. 
+ if (op->getName().getStringRef() == "test.reify_bound") { + if (op->getNumOperands() != 1 || op->getNumResults() != 1 || + !op->getResultTypes()[0].isIndex()) { + op->emitOpError("invalid op"); + return WalkResult::skip(); + } + Value value = op->getOperand(0); + if (value.getType().isa() != + !op->hasAttrOfType("dim")) { + // Op should have "dim" attribute if and only if the operand is an + // index-typed value. + op->emitOpError("invalid op"); + return WalkResult::skip(); + } + + auto dim = value.getType().isIndex() + ? std::nullopt + : std::make_optional( + op->getAttrOfType("dim").getInt()); + + // Reify value bound. + rewriter.setInsertionPointAfter(op); + FailureOr reified; + if (!reifyToFuncArgs) { + // Reify in terms of the op's operands. + reified = reifyValueBound(rewriter, op->getLoc(), + presburger::BoundType::EQ, value, dim); + } else { + // Reify in terms of function block arguments. + auto stopCondition = [](Value v) { + auto bbArg = v.dyn_cast(); + if (!bbArg) + return false; + return isa( + bbArg.getParentBlock()->getParentOp()); + }; + reified = + reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ, + value, dim, stopCondition); + } + if (failed(reified)) { + op->emitOpError("could not reify bound"); + return WalkResult::interrupt(); + } + + // Replace the op with the reified bound. 
+ if (auto val = reified->dyn_cast()) { + rewriter.replaceOp(op, val); + return WalkResult::skip(); + } + Value constOp = rewriter.create( + op->getLoc(), reified->get().cast().getInt()); + rewriter.replaceOp(op, constOp); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +void TestReifyValueBounds::runOnOperation() { + if (failed(testReifyValueBounds(getOperation(), reifyToFuncArgs))) + signalPassFailure(); +} + +namespace mlir { +void registerTestAffineReifyValueBoundsPass() { + PassRegistration(); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -40,6 +40,7 @@ void registerSymbolTestPasses(); void registerRegionTestPasses(); void registerTestAffineDataCopyPass(); +void registerTestAffineReifyValueBoundsPass(); void registerTestDecomposeAffineOpPass(); void registerTestAffineLoopUnswitchingPass(); void registerTestAllReduceLoweringPass(); @@ -151,6 +152,7 @@ registerSymbolTestPasses(); registerRegionTestPasses(); registerTestAffineDataCopyPass(); + registerTestAffineReifyValueBoundsPass(); registerTestDecomposeAffineOpPass(); registerTestAffineLoopUnswitchingPass(); registerTestAllReduceLoweringPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2836,7 +2836,9 @@ ":SCFDialect", ":SCFUtils", ":Support", + ":TensorDialect", ":Transforms", + ":ValueBoundsOpInterface", ":VectorDialect", ":VectorUtils", "//llvm:Support", @@ -5653,8 +5655,12 @@ "include/mlir/Transforms/InliningUtils.h", "lib/Dialect/Tensor/IR/TensorDialect.cpp", "lib/Dialect/Tensor/IR/TensorOps.cpp", + "lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Tensor/IR/Tensor.h", 
+ "include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h", ], - hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"], includes = ["include"], deps = [ ":AffineDialect", @@ -5673,6 +5679,7 @@ ":Support", ":TensorOpsIncGen", ":TilingInterface", + ":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", ], @@ -8808,6 +8815,52 @@ ], ) +td_library( + name = "ValueBoundsOpInterfaceTdFiles", + srcs = [ + "include/mlir/Interfaces/ValueBoundsOpInterface.td", + ], + includes = ["include"], + deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "ValueBoundsOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/ValueBoundsOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/ValueBoundsOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/ValueBoundsOpInterface.td", + deps = [ + ":ValueBoundsOpInterfaceTdFiles", + ], +) + +cc_library( + name = "ValueBoundsOpInterface", + srcs = ["lib/Interfaces/ValueBoundsOpInterface.cpp"], + hdrs = ["include/mlir/Interfaces/ValueBoundsOpInterface.h"], + includes = ["include"], + deps = [ + ":Analysis", + ":DestinationStyleOpInterface", + ":IR", + ":Support", + ":ValueBoundsOpInterfaceIncGen", + "//llvm:Support", + ], +) + cc_library( name = "TilingInterface", srcs = ["lib/Interfaces/TilingInterface.cpp"], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -559,6 +559,7 @@ "//mlir:SCFDialect", "//mlir:Support", "//mlir:Transforms", + "//mlir:ValueBoundsOpInterface", "//mlir:VectorDialect", "//mlir:VectorUtils", ],