diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -17,9 +17,17 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { +class AffineApplyOp; +class Location; +class OpBuilder; +class OpFoldResult; class RewritePatternSet; class RewriterBase; -class AffineApplyOp; +class Value; + +namespace presburger { +enum class BoundType; +} // namespace presburger /// Populate patterns that expand affine index operations into more fundamental /// operations (not necessarily restricted to Affine dialect). @@ -40,6 +48,25 @@ /// maximally compose chains of AffineApplyOps. FailureOr decompose(RewriterBase &rewriter, AffineApplyOp op); +/// Reify a bound for the given index-typed value or shape dimension size in +/// terms of the owning op's operands. `dim` must be `nullopt` if and only if +/// `value` is index-typed. Only `presburger::BoundType::EQ` is supported at +/// the moment. Returns failure if no bound could be computed. +FailureOr reifyValueBound(OpBuilder &b, Location loc, + presburger::BoundType type, Value value, + std::optional dim); + +/// Reify a bound for the given index-typed value or shape dimension size in +/// terms of SSA values for which `stopCondition` is met. `dim` must be +/// `nullopt` if and only if `value` is index-typed. Only +/// `presburger::BoundType::EQ` is supported at the moment. Returns failure if +/// no bound could be computed. Helper ops (tensor.dim, memref.dim, +/// affine.apply) may be created with `b` at `loc`. +FailureOr +reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, + Value value, std::optional dim, + function_ref stopCondition); + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- ValueBoundsOpInterfaceImpl.h - Impl.
of ValueBoundsOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H +#define MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace tensor { +void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace tensor +} // namespace mlir + +#endif // MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -64,6 +64,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -140,6 +141,7 @@ tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerInferTypeOpInterfaceExternalModels(registry); tensor::registerTilingInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_interface(ShapedOpInterfaces) add_mlir_interface(SideEffectInterfaces) add_mlir_interface(TilingInterface)
+add_mlir_interface(ValueBoundsOpInterface) add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -0,0 +1,159 @@ +//===- ValueBoundsOpInterface.h - Value Bounds ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_ +#define MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_ + +#include "mlir/Analysis/FlatValueConstraints.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir { + +/// A list of (value, dim) pairs, where `dim` is `nullopt` for index values. +using ValueDimList = SmallVector>>; + +/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a +/// constraint system and a mapping of constrained variables to index-typed +/// values or dimension sizes of shaped values. +/// +/// Interface implementations of `ValueBoundsOpInterface` use `addBound` to +/// insert constraints about their results and/or region block arguments into +/// the constraint set in the form of an AffineExpr. When a bound should be +/// expressed in terms of another value/dimension, `getExpr` can be used to +/// retrieve an AffineExpr that represents the specified value/dimension. +/// +/// When a value/dimension is retrieved for the first time through `getExpr`, +/// it is added to an internal worklist. See `computeBound` for more details. +/// +/// Note: Any modification of the IR invalidates the data stored in this class.
+class ValueBoundsConstraintSet { +public: + /// Compute a bound for the given index-typed value or shape dimension size. + /// The computed bound is stored in `resultMap`. The operands of the bound are + /// stored in `mapOperands`. An operand is either an index-type SSA value + /// or a shaped value and a dimension. Only EQ bounds are supported so far. + /// + /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound + /// is computed in terms of values for which `stopCondition` evaluates to + /// "true". To that end, the backward slice (reverse use-def chain) of the + /// given value is visited in a worklist-driven manner and the constraint set + /// is populated according to `ValueBoundsOpInterface` for each visited value. + static LogicalResult computeBound(AffineMap &resultMap, + ValueDimList &mapOperands, Location loc, + presburger::BoundType type, Value value, + std::optional dim, + function_ref stopCondition); + + /// Bound the given index-typed value by the given expression. + void addBound(presburger::BoundType type, Value value, AffineExpr expr); + + /// Bound the given shaped value dimension by the given expression. + void addBound(presburger::BoundType type, Value value, int64_t dim, + AffineExpr expr); + + /// Bound the given index-typed value by a constant or SSA value. + void addBound(presburger::BoundType type, Value value, OpFoldResult ofr); + + /// Bound the given shaped value dimension by a constant or SSA value. + void addBound(presburger::BoundType type, Value value, int64_t dim, + OpFoldResult ofr); + + /// Return an expression that represents the given index-typed value or shaped + /// value dimension. If this value/dimension was not used so far, it is added + /// to the worklist. + /// + /// `dim` must be `nullopt` if and only if the given value is of index type. + AffineExpr getExpr(Value value, std::optional dim = std::nullopt); + + /// Return an expression that represents a constant or index-typed SSA value.
+ /// In case of a value, if this value was not used so far, it is added to the + /// worklist. + AffineExpr getExpr(OpFoldResult ofr); + + /// Return an expression that represents a constant. + AffineExpr getExpr(int64_t constant); + +protected: + /// Dimension identifier to indicate a value is index-typed. + static constexpr int64_t kIndexValue = -1; + using ValueDim = std::pair; + /// Create a constraint set with `valueDim` as its only (dimension) variable. + explicit ValueBoundsConstraintSet(ValueDim valueDim); + + /// Iteratively process all elements on the worklist until it is empty. + /// Values that meet `stopCondition` are not processed any further, i.e., + /// their owning op is not queried for additional bounds. + void processWorklist(function_ref stopCondition); + + /// Bound the given column in the underlying constraint set by the given + /// expression. + void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr); + + /// Return the column position of the given value/dimension. Asserts that the + /// value/dimension exists in the constraint set. + int64_t getPos(Value value, int64_t dim = kIndexValue) const; + + /// Insert a value/dimension into the constraint set. If `isSymbol` is set to + /// "false", a dimension is added. Returns the new column position. + /// + /// Note: There are certain affine restrictions wrt. dimensions. E.g., they + /// cannot be multiplied. Furthermore, bounds can only be queried for + /// dimensions but not for symbols. + int64_t insert(ValueDim valueDim, bool isSymbol = true); + + /// Project out the given column in the constraint set. + void projectOut(int64_t pos); + + /// Project out all columns for which the condition holds. + void projectOut(function_ref condition); + + /// Mapping of columns to values/shape dimensions. + SmallVector positionToValueDim; + /// Reverse mapping of values/shape dimensions to columns. + DenseMap valueDimToPosition; + + /// Worklist of values/shape dimensions that have not been processed yet. + SetVector worklist; + + /// Constraint system of equalities and inequalities.
+ FlatConstraints cstr; + + /// Builder for constructing affine expressions. + Builder builder; +}; + +} // namespace mlir + +#include "mlir/Interfaces/ValueBoundsOpInterface.h.inc" + +namespace mlir { + +/// Default implementation for destination style ops: Tied OpResults and +/// OpOperands have the same type. +template +struct DstValueBoundsOpInterfaceExternalModel + : public ValueBoundsOpInterface::ExternalModel< + DstValueBoundsOpInterfaceExternalModel, ConcreteOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto dstOp = cast(op); + assert(value.getDefiningOp() == dstOp && "invalid value"); + + Value tiedOperand = dstOp.getTiedOpOperand(value.cast())->get(); + cstr.addBound(presburger::BoundType::EQ, value, dim, + cstr.getExpr(tiedOperand, dim)); + } +}; + +} // namespace mlir + +#endif // MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td @@ -0,0 +1,63 @@ +//===-- ValueBoundsOpInterface.td - Value Bounds -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef VALUEBOUNDSOPINTERFACE +#define VALUEBOUNDSOPINTERFACE + +include "mlir/IR/OpBase.td" + +def ValueBoundsOpInterface : OpInterface<"ValueBoundsOpInterface"> { + let description = [{ + This interface allows operations with index-typed and/or shaped value-typed + results/block arguments to specify range bounds. These bounds are stored in + a constraint set.
The constraint set can then be queried to compute bounds + in terms of other values that are stored in the constraint set. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate the constraint set with bounds for the given index-typed + value. + + Note: If `value` is a block argument, it must belong to an entry block + of a region. Unstructured control flow graphs are not supported at the + moment. + }], + /*retType=*/"void", + /*methodName=*/"populateBoundsForIndexValue", + /*args=*/(ins "::mlir::Value":$value, + "::mlir::ValueBoundsConstraintSet &":$cstr), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Populate the constraint set with bounds for the size of the specified + dimension of the given shaped value. + + Note: If `value` is a block argument, it must belong to an entry block + of a region. Unstructured control flow graphs are not supported at the + moment. + }], + /*retType=*/"void", + /*methodName=*/"populateBoundsForShapedValueDim", + /*args=*/(ins "::mlir::Value":$value, + "int64_t":$dim, + "::mlir::ValueBoundsConstraintSet &":$cstr), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("not implemented"); + }] + >, + ]; +} + +#endif // VALUEBOUNDSOPINTERFACE diff --git a/mlir/lib/Analysis/FlatValueConstraints.cpp b/mlir/lib/Analysis/FlatValueConstraints.cpp --- a/mlir/lib/Analysis/FlatValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatValueConstraints.cpp @@ -457,7 +457,7 @@ SmallVectorImpl *lbMaps, SmallVectorImpl *ubMaps, bool closedUB) { - assert(num < getNumDimVars() && "invalid range"); + assert(offset + num <= getNumDimVars() && "invalid range"); // Basic simplification.
normalizeConstraintsByGCD(); diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ LoopUnroll.cpp LoopUnrollAndJam.cpp PipelineDataTransfer.cpp + ReifyValueBounds.cpp SuperVectorize.cpp SimplifyAffineStructures.cpp @@ -33,7 +34,9 @@ MLIRPass MLIRSCFUtils MLIRSideEffectInterfaces + MLIRTensorDialect MLIRTransformUtils + MLIRValueBoundsOpInterface MLIRVectorDialect MLIRVectorUtils MLIRVectorToLLVM diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -0,0 +1,83 @@ +//===- ReifyValueBounds.cpp --- Reify value bounds with affine ops ------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/Transforms/Transforms.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; + +FailureOr mlir::reifyValueBound(OpBuilder &b, Location loc, + presburger::BoundType type, + Value value, + std::optional dim) { + auto stopCondition = [&](Value v) { + // Reify in terms of SSA values that are different from `value`. + return v != value; + }; + return reifyValueBound(b, loc, type, value, dim, stopCondition); +} + +FailureOr +mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, + Value value, std::optional dim, + function_ref stopCondition) { + // Compute bound.
+ AffineMap boundMap; + ValueDimList mapOperands; + if (failed(ValueBoundsConstraintSet::computeBound( + boundMap, mapOperands, loc, type, value, dim, stopCondition))) + return failure(); + + // Materialize tensor.dim/memref.dim ops. + SmallVector operands; + for (const auto &mapOperand : mapOperands) { + Value operandValue = mapOperand.first; + std::optional operandDim = mapOperand.second; + + if (!operandDim.has_value()) { + // This is an index-typed value. + assert(operandValue.getType().isIndex() && "expected index type"); + operands.push_back(operandValue); + continue; + } + + assert(operandValue.getType().cast().isDynamicDim(*operandDim) && + "expected dynamic dim"); + if (operandValue.getType().isa()) { + // A tensor dimension is used: generate a tensor.dim. + operands.push_back(b.create(loc, operandValue, *operandDim)); + } else if (operandValue.getType().isa()) { + // A memref dimension is used: generate a memref.dim. + operands.push_back(b.create(loc, operandValue, *operandDim)); + } else { + llvm_unreachable("cannot generate DimOp for unsupported shaped type"); + } + } + + // Simplify and return bound. + mlir::canonicalizeMapAndOperands(&boundMap, &operands); + // Check for special cases where no affine.apply op is needed. + if (boundMap.isSingleConstant()) { + // Bound is a constant: return an IntegerAttr. + return static_cast( + b.getIndexAttr(boundMap.getSingleConstantResult())); + } + // No affine.apply op is needed if the bound is a single SSA value. + if (auto expr = boundMap.getResult(0).dyn_cast()) + return static_cast(operands[expr.getPosition()]); + if (auto expr = boundMap.getResult(0).dyn_cast()) + return static_cast( + operands[expr.getPosition() + boundMap.getNumDims()]); + // General case: build affine.apply op.
+ return static_cast( + b.create(loc, boundMap, operands).getResult()); +} diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -3,11 +3,13 @@ TensorInferTypeOpInterfaceImpl.cpp TensorOps.cpp TensorTilingInterfaceImpl.cpp + ValueBoundsOpInterfaceImpl.cpp ) add_mlir_dialect_library(MLIRTensorDialect TensorDialect.cpp TensorOps.cpp + ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Tensor @@ -32,6 +34,7 @@ MLIRShapedOpInterfaces MLIRSideEffectInterfaces MLIRSupport + MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,134 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; +using presburger::BoundType; + +namespace mlir { +namespace tensor { +namespace { + +struct CastOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto castOp = cast(op); + assert(value == castOp.getResult() && "invalid value"); + + if (castOp.getResult().getType().isa() && + castOp.getSource().getType().isa()) { + cstr.addBound(BoundType::EQ, value, dim, + cstr.getExpr(castOp.getSource(), dim)); + } + } +}; + +struct DimOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto dimOp = cast(op); + assert(value == dimOp.getResult() && "invalid value"); + + auto constIndex = dimOp.getConstantIndex(); + if (!constIndex.has_value()) + return; + cstr.addBound(BoundType::EQ, value, + cstr.getExpr(dimOp.getSource(), *constIndex)); + } +}; + +struct EmptyOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto emptyOp = cast(op); + assert(value == emptyOp.getResult() && "invalid value"); + + cstr.addBound(BoundType::EQ, value, dim, + cstr.getExpr(emptyOp.getMixedSizes()[dim])); + } +}; + +struct ExtractSliceOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto extractSliceOp = cast(op); + assert(value ==
extractSliceOp.getResult() && "invalid value"); + + llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims(); + int64_t ctr = -1; + for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) { + // Skip over rank-reduced dimensions. + if (dropped.test(i)) + continue; + if (++ctr == dim) { + cstr.addBound(BoundType::EQ, value, dim, + extractSliceOp.getMixedSizes()[i]); + return; + } + } + llvm_unreachable("could not find non-rank-reduced dim"); + } +}; + +struct PadOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto padOp = cast(op); + assert(value == padOp.getResult() && "invalid value"); + + AffineExpr expr = cstr.getExpr(padOp.getSource(), dim) + + cstr.getExpr(padOp.getMixedLowPad()[dim]) + + cstr.getExpr(padOp.getMixedHighPad()[dim]); + cstr.addBound(BoundType::EQ, value, dim, expr); + } +}; + +struct RankOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto rankOp = cast(op); + assert(value == rankOp.getResult() && "invalid value"); + + auto tensorType = rankOp.getTensor().getType().dyn_cast(); + if (!tensorType) + return; + cstr.addBound(BoundType::EQ, value, cstr.getExpr(tensorType.getRank())); + } +}; + +} // namespace +} // namespace tensor +} // namespace mlir + +void mlir::tensor::registerValueBoundsOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + tensor::CastOp::attachInterface(*ctx); + tensor::DimOp::attachInterface(*ctx); + tensor::EmptyOp::attachInterface(*ctx); + tensor::ExtractSliceOp::attachInterface( + *ctx); + tensor::InsertOp::attachInterface< + DstValueBoundsOpInterfaceExternalModel>(*ctx); + tensor::InsertSliceOp::attachInterface< + DstValueBoundsOpInterfaceExternalModel>(*ctx);
+ tensor::PadOp::attachInterface(*ctx); + tensor::RankOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -14,6 +14,7 @@ ShapedOpInterfaces.cpp SideEffectInterfaces.cpp TilingInterface.cpp + ValueBoundsOpInterface.cpp VectorInterfaces.cpp ViewLikeInterface.cpp ) @@ -49,6 +50,7 @@ add_mlir_interface_library(ShapedOpInterfaces) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) +add_mlir_interface_library(ValueBoundsOpInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -0,0 +1,337 @@ +//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/APSInt.h" + +using namespace mlir; +using presburger::BoundType; +using presburger::VarKind; + +namespace mlir { +#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" +} // namespace mlir + +/// If ofr is a constant integer or an IntegerAttr, return the integer. +static std::optional getConstantIntValue(OpFoldResult ofr) { + // Case 1: Check for Constant integer.
+ if (auto val = ofr.dyn_cast()) { + APSInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal.getSExtValue(); + return std::nullopt; + } + // Case 2: Check for IntegerAttr. + Attribute attr = ofr.dyn_cast(); + if (auto intAttr = attr.dyn_cast_or_null()) + return intAttr.getValue().getSExtValue(); + return std::nullopt; +} + +ValueBoundsConstraintSet::ValueBoundsConstraintSet(ValueDim valueDim) + : builder(valueDim.first.getContext()) { + insert(valueDim, /*isSymbol=*/false); +} + +#ifndef NDEBUG +static void assertValidValueDim(Value value, std::optional dim) { + if (value.getType().isIndex()) { + assert(!dim.has_value() && "invalid dim value"); + } else if (auto shapedType = value.getType().dyn_cast()) { + assert(dim.has_value() && *dim >= 0 && "invalid dim value"); + if (shapedType.hasRank()) + assert(*dim < shapedType.getRank() && "invalid dim value"); + } else { + llvm_unreachable("unsupported type"); + } +} +#endif // NDEBUG + +void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos, + AffineExpr expr) { + LogicalResult status = cstr.addBound( + type, pos, + AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr)); + (void)status; + assert(succeeded(status) && "failed to add bound to constraint system"); +} + +void ValueBoundsConstraintSet::addBound(BoundType type, Value value, + AffineExpr expr) { + assert(value.getType().isIndex() && "expected index type"); + assert((value.isa() || + value.cast().getOwner()->isEntryBlock()) && + "unstructured control flow is not supported"); + + addBound(type, getPos(value), expr); +} + +void ValueBoundsConstraintSet::addBound(BoundType type, Value value, + int64_t dim, AffineExpr expr) { +#ifndef NDEBUG + assertValidValueDim(value, dim); + assert((value.isa() || + value.cast().getOwner()->isEntryBlock()) && + "unstructured control flow is not supported"); +#endif // NDEBUG + + addBound(type, getPos(value, dim), expr); +} + +void ValueBoundsConstraintSet::addBound(BoundType type, Value value,
OpFoldResult ofr) { + assert(value.getType().isIndex() && "expected index type"); + addBound(type, getPos(value), getExpr(ofr)); +} + +void ValueBoundsConstraintSet::addBound(BoundType type, Value value, + int64_t dim, OpFoldResult ofr) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + addBound(type, getPos(value, dim), getExpr(ofr)); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(Value value, + std::optional dim) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + + auto shapedType = value.getType().dyn_cast(); + if (shapedType) { + // Static dimension: return constant directly. + if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) + return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); + } else { + // Constant index value: return directly. + if (auto constInt = getConstantIntValue(value)) + return builder.getAffineConstantExpr(*constInt); + } + + // Dynamic value: add to constraint set. + ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); + if (valueDimToPosition.find(valueDim) == valueDimToPosition.end()) + (void)insert(valueDim); + int64_t pos = getPos(value, dim.value_or(kIndexValue)); + return pos < cstr.getNumDimVars() + ? builder.getAffineDimExpr(pos) + : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { + if (Value value = ofr.dyn_cast()) + return getExpr(value, /*dim=*/std::nullopt); + auto constInt = getConstantIntValue(ofr); + assert(constInt.has_value() && "expected Integer constant"); + return builder.getAffineConstantExpr(*constInt); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) { + return builder.getAffineConstantExpr(constant); +} + +int64_t ValueBoundsConstraintSet::insert(ValueDim valueDim, bool isSymbol) { + assert((valueDimToPosition.find(valueDim) == valueDimToPosition.end()) && + "already mapped"); + // Inserting a dimension shifts all symbol positions, but worklist entries + // are not updated. Dimensions may only be inserted while it is empty. + assert((isSymbol || worklist.empty()) && + "inserting a dimension would invalidate worklist positions"); + int64_t pos = isSymbol ?
cstr.appendVar(VarKind::Symbol) + : cstr.appendVar(VarKind::SetDim); + positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); + // Update reverse mapping. + for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) + valueDimToPosition[positionToValueDim[i]] = i; + + worklist.insert(pos); + return pos; +} + +int64_t ValueBoundsConstraintSet::getPos(Value value, int64_t dim) const { +#ifndef NDEBUG + assertValidValueDim(value, dim == kIndexValue + ? std::nullopt + : std::make_optional(dim)); +#endif // NDEBUG + + auto it = valueDimToPosition.find(std::make_pair(value, dim)); + assert(it != valueDimToPosition.end() && "expected mapped entry"); + return it->second; +} + +static Operation *getOwnerOfValue(Value value) { + if (auto bbArg = value.dyn_cast()) + return bbArg.getOwner()->getParentOp(); + return value.getDefiningOp(); +} + +void ValueBoundsConstraintSet::processWorklist( + function_ref stopCondition) { + while (!worklist.empty()) { + int64_t pos = worklist.pop_back_val(); + ValueDim valueDim = positionToValueDim[pos]; + Value value = valueDim.first; + int64_t dim = valueDim.second; + + // Check for static dim size. + if (dim != kIndexValue) { + auto shapedType = value.getType().cast(); + if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { + addBound(BoundType::EQ, value, dim, + builder.getAffineConstantExpr(shapedType.getDimSize(dim))); + continue; + } + } + + // Do not process any further if the stop condition is met. + if (stopCondition(value)) + continue; + + // Query `ValueBoundsOpInterface` for constraints. New items may be added to + // the worklist.
+ auto valueBoundsOp = + dyn_cast_or_null(getOwnerOfValue(value)); + if (!valueBoundsOp) + continue; + if (dim == kIndexValue) { + valueBoundsOp.populateBoundsForIndexValue(value, *this); + } else { + valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this); + } + } +} + +void ValueBoundsConstraintSet::projectOut(int64_t pos) { + assert(pos >= 0 && pos < positionToValueDim.size() && "invalid position"); + cstr.projectOut(pos); + bool erased = valueDimToPosition.erase(positionToValueDim[pos]); + (void)erased; + assert(erased && "inconsistent reverse mapping"); + positionToValueDim.erase(positionToValueDim.begin() + pos); + // Update reverse mapping. + for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) + valueDimToPosition[positionToValueDim[i]] = i; +} + +void ValueBoundsConstraintSet::projectOut( + function_ref condition) { + int64_t nextPos = 0; + while (nextPos < positionToValueDim.size()) { + if (condition(positionToValueDim[nextPos])) { + projectOut(nextPos); + // The column was projected out so another column is now at that position. + // Do not increase the counter. + } else { + ++nextPos; + } + } +} + +LogicalResult ValueBoundsConstraintSet::computeBound( + AffineMap &resultMap, ValueDimList &mapOperands, Location loc, + presburger::BoundType type, Value value, std::optional dim, + function_ref stopCondition) { + // Only EQ bounds are supported at the moment. + assert(type == BoundType::EQ && "unsupported bound type"); + + // Process the backward slice of `value` (i.e., reverse use-def chain) until + // `stopCondition` is met. + ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); + ValueBoundsConstraintSet cstr(valueDim); + cstr.processWorklist(stopCondition); + + // Project out all variables (apart from `valueDim`) that do not match the + // stop condition. + cstr.projectOut([&](ValueDim p) { + // Do not project out `valueDim`.
+ if (valueDim == p) + return false; + return !stopCondition(p.first); + }); + + // Compute lower and upper bounds for `valueDim`. + int64_t pos = cstr.getPos(value, dim.value_or(kIndexValue)); + SmallVector lb(1), ub(1); + cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, + /*getClosedUB=*/true); + // Note: There are TODOs in the implementation of `getSliceBounds`. In such a + // case, no lower/upper bound can be computed at the moment. + if (lb.empty() || !lb[0] || ub.empty() || !ub[0] || + lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1) + return failure(); + + // Look for same lower and upper bound: EQ bound. + if (ub[0] != lb[0]) + return failure(); + + // Gather all SSA values that are used in the computed bound. + mapOperands.clear(); + assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && + "inconsistent mapping state"); + Builder b(value.getContext()); + SmallVector replacementDims, replacementSymbols; + int64_t numDims = 0, numSymbols = 0; + for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) { + // Skip `value`. + if (i == pos) + continue; + // Check if the position `i` is used in the generated bound. If so, it must + // be included in the generated affine.apply op. + bool used = false; + bool isDim = i < cstr.cstr.getNumDimVars(); + if (isDim) { + if (lb[0].isFunctionOfDim(i)) + used = true; + } else { + if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) + used = true; + } + + if (!used) { + // Not used: Remove dim/symbol from the result.
+ if (isDim) { + replacementDims.push_back(b.getAffineConstantExpr(0)); + } else { + replacementSymbols.push_back(b.getAffineConstantExpr(0)); + } + continue; + } + + if (isDim) { + replacementDims.push_back(b.getAffineDimExpr(numDims++)); + } else { + replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++)); + } + + ValueDim usedValueDim = cstr.positionToValueDim[i]; + Value usedValue = usedValueDim.first; + int64_t usedDim = usedValueDim.second; + if (usedDim == ValueBoundsConstraintSet::kIndexValue) { + // An index-type value is used: can be used directly in the affine.apply + // op. + assert(usedValue.getType().isIndex() && "expected index type"); + mapOperands.push_back(std::make_pair(usedValue, std::nullopt)); + continue; + } + + // Unranked shaped values may also be operands (see `getExpr`). + assert((!usedValue.getType().cast().hasRank() || + usedValue.getType().cast().isDynamicDim(usedDim)) && + "expected dynamic dim"); + mapOperands.push_back(std::make_pair(usedValue, usedDim)); + } + + resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols, + numDims, numSymbols); + return success(); +} diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \ +// RUN: -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: func @reify_through_chain( +// CHECK-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[sz0]], %[[c10]], %[[sz2]] +func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %0 = tensor.empty(%sz0, %sz2) : tensor + %1 = tensor.cast %0 : tensor to tensor + %pos = arith.constant 0 : index + %f = arith.constant 0.0 : f32 + %2 = tensor.insert %f into %1[%pos, %pos, %pos] : tensor + %3 = tensor.dim %2, %c2 : tensor + + %4 = "test.reify_bound"(%2) {dim = 0} : (tensor)
-> (index) + %5 = "test.reify_bound"(%2) {dim = 1} : (tensor) -> (index) + %6 = "test.reify_bound"(%3) : (index) -> (index) + + return %4, %5, %6 : index, index, index +} diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ +// RUN: -split-input-file | FileCheck %s + +func.func @unknown_op() -> index { + %0 = "test.foo"() : () -> (tensor) + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cast( +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[c10]] +func.func @cast(%t: tensor<10xf32>) -> index { + %0 = tensor.cast %t : tensor<10xf32> to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +func.func @cast_unranked(%t: tensor<*xf32>) -> index { + %0 = tensor.cast %t : tensor<*xf32> to tensor + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @dim( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: return %[[dim]] +func.func @dim(%t: tensor) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %t, %c0 : tensor + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @empty( +// CHECK-SAME: %[[sz:.*]]: index +// CHECK: %[[c6:.*]] = arith.constant 6 : index +// CHECK: return %[[c6]], %[[sz]] +func.func @empty(%sz: index) -> (index, index) { + %0 = tensor.empty(%sz) : tensor<6x?xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<6x?xf32>) 
-> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<6x?xf32>) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_dynamic( +// CHECK-SAME: %[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_dynamic(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[2][%sz][1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_static( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: return %[[c5]] +func.func @extract_slice_static(%t: tensor) -> index { + %0 = tensor.extract_slice %t[2][5][1] : tensor to tensor<5xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_rank_reduce( +// CHECK-SAME: %[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_rank_reduce(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[0, 2][1, %sz][1, 1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @insert( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: return %[[dim]] +func.func @insert(%t: tensor, %f: f32, %pos: index) -> index { + %0 = tensor.insert %f into %t[%pos] : tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 12)> +// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK-LABEL: func @pad( +// CHECK-SAME: %[[t:.*]]: tensor, %[[a:.*]]: index, %[[b:.*]]: index +// CHECK: %[[bound1:.*]] = affine.apply #[[$map]]()[%[[b]]] +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim0:.*]] = 
tensor.dim %[[t]], %[[c0]] +// CHECK: %[[bound0:.*]] = affine.apply #[[$map1]]()[%[[dim0]], %[[a]]] +// CHECK: return %[[bound0]], %[[bound1]] +func.func @pad(%t: tensor, %a: index, %b: index) -> (index, index) { + %pad = arith.constant 0.0 : f32 + %0 = tensor.pad %t low[%a, 5] high[%a, %b] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %pad : f32 + } : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @rank( +// CHECK-SAME: %[[t:.*]]: tensor<5xf32> +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: return %[[c1]] +func.func @rank(%t: tensor<5xf32>) -> index { + %0 = tensor.rank %t : tensor<5xf32> + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt --- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt @@ -4,6 +4,7 @@ TestAffineLoopUnswitching.cpp TestAffineLoopParametricTiling.cpp TestDecomposeAffineOps.cpp + TestReifyValueBounds.cpp TestLoopFusion.cpp TestLoopMapping.cpp TestLoopPermutation.cpp diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -0,0 +1,120 @@ +//===- TestReifyValueBounds.cpp - Test value bounds reification -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Pass/Pass.h"
+
+#define PASS_NAME "test-affine-reify-value-bounds"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that reifies the value bounds requested by "test.reify_bound" ops.
+struct TestReifyValueBounds
+    : public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests value bounds reification";
+  }
+  TestReifyValueBounds() = default;
+  TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass) {}
+
+  void runOnOperation() override;
+
+private:
+  Option<bool> reifyToFuncArgs{
+      *this, "reify-to-func-args",
+      llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
+};
+
+} // namespace
+
+/// Look for "test.reify_bound" ops in the input and replace their results with
+/// the reified values.
+static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
+                                          bool reifyToFuncArgs) {
+  IRRewriter rewriter(funcOp.getContext());
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    // Look for test.reify_bound ops.
+    if (op->getName().getStringRef() == "test.reify_bound") {
+      if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
+          !op->getResultTypes()[0].isIndex()) {
+        op->emitOpError("invalid op");
+        return WalkResult::interrupt();
+      }
+      Value value = op->getOperand(0);
+      bool isIndex = value.getType().isa<IndexType>();
+      if (isIndex == op->hasAttrOfType<IntegerAttr>("dim") ||
+          (!isIndex && !value.getType().isa<ShapedType>())) {
+        // Op must have a "dim" attribute if and only if the operand is a
+        // shaped value. Any other operand must be index-typed.
+        op->emitOpError("invalid op");
+        return WalkResult::interrupt();
+      }
+
+      auto dim = isIndex ? std::nullopt
+                         : std::make_optional<int64_t>(
+                               op->getAttrOfType<IntegerAttr>("dim").getInt());
+
+      // Reify value bound.
+      rewriter.setInsertionPointAfterValue(value);
+      FailureOr<OpFoldResult> reified;
+      if (!reifyToFuncArgs) {
+        // Reify in terms of the op's operands.
+        reified = reifyValueBound(rewriter, op->getLoc(),
+                                  presburger::BoundType::EQ, value, dim);
+      } else {
+        // Reify in terms of function block arguments.
+        auto stopCondition = [](Value v) {
+          auto bbArg = v.dyn_cast<BlockArgument>();
+          if (!bbArg)
+            return false;
+          return isa<FunctionOpInterface>(
+              bbArg.getParentBlock()->getParentOp());
+        };
+        reified =
+            reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ,
+                            value, dim, stopCondition);
+      }
+      if (failed(reified)) {
+        op->emitOpError("could not reify bound");
+        return WalkResult::interrupt();
+      }
+
+      // Replace the op with the reified bound.
+      if (auto val = reified->dyn_cast<Value>()) {
+        rewriter.replaceOp(op, val);
+        return WalkResult::skip();
+      }
+      Value constOp = rewriter.create<arith::ConstantIndexOp>(
+          op->getLoc(), reified->get<Attribute>().cast<IntegerAttr>().getInt());
+      rewriter.replaceOp(op, constOp);
+      return WalkResult::skip();
+    }
+    return WalkResult::advance();
+  });
+  return failure(result.wasInterrupted());
+}
+
+void TestReifyValueBounds::runOnOperation() {
+  if (failed(testReifyValueBounds(getOperation(), reifyToFuncArgs)))
+    signalPassFailure();
+}
+
+namespace mlir {
+void registerTestAffineReifyValueBoundsPass() {
+  PassRegistration<TestReifyValueBounds>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -40,6 +40,7 @@
 void registerSymbolTestPasses();
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
+void registerTestAffineReifyValueBoundsPass();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -149,6 +150,7 @@
   registerSymbolTestPasses();
   registerRegionTestPasses();
   registerTestAffineDataCopyPass();
+  registerTestAffineReifyValueBoundsPass();
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2716,7 +2716,9 @@
         ":SCFDialect",
         ":SCFUtils",
         ":Support",
+        ":TensorDialect",
         ":Transforms",
+        ":ValueBoundsOpInterface",
         ":VectorDialect",
         ":VectorUtils",
         "//llvm:Support",
@@ -5458,8 +5460,12 @@
         "include/mlir/Transforms/InliningUtils.h",
         "lib/Dialect/Tensor/IR/TensorDialect.cpp",
         "lib/Dialect/Tensor/IR/TensorOps.cpp",
+        "lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Tensor/IR/Tensor.h",
+        "include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h",
     ],
-    hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"],
     includes = ["include"],
     deps = [
         ":AffineDialect",
@@ -5478,6 +5484,7 @@
         ":Support",
         ":TensorOpsIncGen",
         ":TilingInterface",
+        ":ValueBoundsOpInterface",
         ":ViewLikeInterface",
         "//llvm:Support",
     ],
@@ -8534,6 +8541,52 @@
     ],
 )
 
+td_library(
+    name = "ValueBoundsOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Interfaces/ValueBoundsOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "ValueBoundsOpInterfaceIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Interfaces/ValueBoundsOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Interfaces/ValueBoundsOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Interfaces/ValueBoundsOpInterface.td",
+    deps = [
+        ":ValueBoundsOpInterfaceTdFiles",
+    ],
+)
+
+cc_library(
+    name = "ValueBoundsOpInterface",
+    srcs = ["lib/Interfaces/ValueBoundsOpInterface.cpp"],
+    hdrs = ["include/mlir/Interfaces/ValueBoundsOpInterface.h"],
+    includes = ["include"],
+    deps = [
+        ":Analysis",
+        ":DestinationStyleOpInterface",
+        ":IR",
+        ":Support",
+        ":ValueBoundsOpInterfaceIncGen",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "TilingInterface",
     srcs = ["lib/Interfaces/TilingInterface.cpp"],
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -541,6 +541,7 @@
         "//mlir:SCFDialect",
         "//mlir:Support",
         "//mlir:Transforms",
+        "//mlir:ValueBoundsOpInterface",
         "//mlir:VectorDialect",
         "//mlir:VectorUtils",
     ],