diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(IR) add_subdirectory(TransformOps) +add_subdirectory(Transforms) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(ValueBoundsOpInterface) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h @@ -0,0 +1,147 @@ +//===- ValueBoundsOpInterface.h - Value Bounds ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACE_H_ +#define MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACE_H_ + +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir { +namespace linalg { + +/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a +/// constraint system and mapping of columns to values/shape dimensions. 
+/// +/// Note: This class maintains its own mapping to SSA values; no SSA values are +/// mapped in the underlying `FlatAffineValueConstraints`. This is because not +/// only SSA values but also shape dimensions of SSA values must be mapped. +class ValueBoundsConstraintSet { +public: + /// Reify a bound for the given index-typed value or shape dimension size in + /// terms of the owning op's operands. + static FailureOr + reifyBound(OpBuilder &b, Location loc, + presburger::IntegerPolyhedron::BoundType type, Value value, + int64_t dim = kIndexValue); + + /// Reify a bound for the given index-typed value or shape dimension size in + /// terms of SSA values for which `stopCondition` is met. + static FailureOr + reifyBound(OpBuilder &b, Location loc, + presburger::IntegerPolyhedron::BoundType type, Value value, + int64_t dim, function_ref stopCondition); + + /// Dimension identifier to indicate a value is index-typed. + static const int64_t kIndexValue = -1; + + /// Bound the given index-typed value by the given expression. + void addBound(presburger::IntegerPolyhedron::BoundType type, Value value, + AffineExpr expr); + + /// Bound the given shaped value dimension by the given expression. + void addBound(presburger::IntegerPolyhedron::BoundType type, Value value, + int64_t dim, AffineExpr expr); + + /// Bound the given index-typed value by a constant or SSA value. + void addBound(presburger::IntegerPolyhedron::BoundType type, Value value, + OpFoldResult ofr); + + /// Bound the given shaped value dimension by a constant or SSA value. + void addBound(presburger::IntegerPolyhedron::BoundType type, Value value, + int64_t dim, OpFoldResult ofr); + + /// Return an expression that represents the given index-typed value or shaped + /// value dimension. If this value/dimension was not used so far, it is added + /// to the worklist.
+ AffineExpr getExpr(Value value, int64_t dim = kIndexValue); + + /// Return an expression that represents a constant or index-typed SSA value. + /// In case of a value, if this value was not used so far, it is added to the + /// worklist. + AffineExpr getExpr(OpFoldResult ofr); + + /// Return an expression that represents a constant. + AffineExpr getExpr(int64_t val); + +private: + using ValueDim = std::pair; + ValueBoundsConstraintSet(ValueDim valueDim); + + /// Iteratively process all elements on the worklist until an index-typed + /// value or shaped value meets `stopCondition`. Such values are not processed + /// any further. + void processWorklist(function_ref stopCondition); + + /// Bound the given column in the underlying constraint set by the given + /// expression. + void addBound(presburger::IntegerPolyhedron::BoundType type, int64_t pos, + AffineExpr expr); + + /// Return the column position of the given value/dimension. Asserts that the + /// value/dimension exists in the constraint set. + int64_t getPos(Value value, int64_t dim = kIndexValue) const; + + /// Insert a value/dimension into the constraint set. If `isSymbol` is set to + /// "false", a dimension is added. + /// + /// Note: There are certain affine restrictions wrt. dimensions. E.g., they + /// cannot be multiplied. Furthermore, bounds can only be queried for + /// dimensions but not for symbols. + int64_t insert(ValueDim valueDim, bool isSymbol = true); + + /// Project out the given column in the constraint set. + void projectOut(int64_t pos); + + /// Mapping of columns to values/shape dimensions. + SmallVector positionToValueDim; + /// Reverse mapping of values/shape dimensions to columns. + DenseMap valueDimToPosition; + + /// Worklist of values/shape dimensions that have not been processed yet. + SetVector worklist; + + /// Constraint system of equalities and inequalities. + FlatAffineValueConstraints cstr; + + /// Builder for constructing affine expressions. 
+ Builder builder; +}; +} // namespace linalg +} // namespace mlir + +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h.inc" + +namespace mlir { +namespace linalg { + +/// Default implementation for destination style ops: Tied OpResults and +/// OpOperands have the same type. +template +struct DstValueBoundsOpInterfaceExternalModel + : public ValueBoundsOpInterface::ExternalModel< + DstValueBoundsOpInterfaceExternalModel, ConcreteOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto dstOp = cast(op); + assert(value.getDefiningOp() == dstOp); + + Value tiedOperand = dstOp.getTiedOpOperand(value.cast())->get(); + cstr.addBound(presburger::IntegerPolyhedron::BoundType::EQ, value, dim, + cstr.getExpr(tiedOperand, dim)); + } +}; + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.td @@ -0,0 +1,55 @@ +//===-- ValueBoundsOpInterface.td - Value Bounds -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef VALUEBOUNDSOPINTERFACE +#define VALUEBOUNDSOPINTERFACE + +include "mlir/IR/OpBase.td" + +def ValueBoundsOpInterface : OpInterface<"ValueBoundsOpInterface"> { + let description = [{ + This interface allows operations with index-typed and/or shaped value-typed + results/block arguments to specify range bounds. These bounds are stored in + a constraint set. 
The constraint set can then be queried to compute bounds + in terms of other values that are stored in the constraint set. + }]; + let cppNamespace = "::mlir::linalg"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate the constraint set with bounds for the given index-typed + value. + }], + /*retType=*/"void", + /*methodName=*/"populateBoundsForIndexValue", + /*args=*/(ins "::mlir::Value":$value, + "::mlir::linalg::ValueBoundsConstraintSet &":$cstr), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Populate the constraint set with bounds for the size of the specified + dimension of the given shaped value. + }], + /*retType=*/"void", + /*methodName=*/"populateBoundsForShapedValueDim", + /*args=*/(ins "::mlir::Value":$value, + "int64_t":$dim, + "::mlir::linalg::ValueBoundsConstraintSet &":$cstr), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("not implemented"); + }] + >, + ]; +} + +#endif // VALUEBOUNDSOPINTERFACE diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerValueBoundsOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_VALUEBOUNDSOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -41,6 +41,7 @@ #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -132,6 +133,7 @@ registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); + linalg::registerValueBoundsOpInterfaceExternalModels(registry); memref::registerBufferizableOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -1010,7 +1010,7 @@ unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl *lbMaps, SmallVectorImpl *ubMaps, bool getClosedUB) { - assert(num < getNumDimVars() && "invalid range"); + 
assert(offset + num <= getNumDimVars() && "invalid range"); // Basic simplification. normalizeConstraintsByGCD(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -30,6 +30,8 @@ Tiling.cpp TilingInterfaceImpl.cpp Transforms.cpp + ValueBoundsOpInterface.cpp + ValueBoundsOpInterfaceImpl.cpp Vectorization.cpp ADDITIONAL_HEADER_DIRS @@ -39,6 +41,7 @@ MLIRLinalgPassIncGen LINK_LIBS PUBLIC + MLIRAffineAnalysis MLIRAffineDialect MLIRAffineUtils MLIRAnalysis diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp @@ -0,0 +1,324 @@ +//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::linalg; +using presburger::IntegerPolyhedron; + +namespace mlir { +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp.inc" +} // namespace mlir + +ValueBoundsConstraintSet::ValueBoundsConstraintSet(ValueDim valueDim) + : builder(valueDim.first.getContext()) { + insert(valueDim, /*isSymbol=*/false); +} + +#ifndef NDEBUG +static void assertValidValueDim(Value value, int64_t dim) { + if (value.getType().isIndex()) { + assert(dim == ValueBoundsConstraintSet::kIndexValue && "invalid dim value"); + } else if (auto shapedType = value.getType().dyn_cast()) { + assert(dim >= 0 && "invalid dim value"); + if (shapedType.hasRank()) + assert(dim < shapedType.getRank() && "invalid dim value"); + } else { + llvm_unreachable("unsupported type"); + } +} +#endif // NDEBUG + +void ValueBoundsConstraintSet::addBound( + presburger::IntegerPolyhedron::BoundType type, int64_t pos, + AffineExpr expr) { + LogicalResult status = cstr.addBound( + type, pos, + AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr)); + (void)status; + assert(succeeded(status) && "failed to add bound to constraint system"); +} + +void ValueBoundsConstraintSet::addBound( + presburger::IntegerPolyhedron::BoundType type, Value value, + AffineExpr expr) { + assert(value.getType().isIndex() && "expected index type"); + addBound(type, getPos(value), expr); +} + +void ValueBoundsConstraintSet::addBound( + presburger::IntegerPolyhedron::BoundType type, Value value, int64_t dim, + AffineExpr expr) { +#ifndef 
NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + addBound(type, getPos(value, dim), expr); +} + +void ValueBoundsConstraintSet::addBound( + presburger::IntegerPolyhedron::BoundType type, Value value, + OpFoldResult ofr) { + assert(value.getType().isIndex() && "expected index type"); + addBound(type, getPos(value), getExpr(ofr)); +} + +void ValueBoundsConstraintSet::addBound( + presburger::IntegerPolyhedron::BoundType type, Value value, int64_t dim, + OpFoldResult ofr) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + addBound(type, getPos(value, dim), getExpr(ofr)); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(Value value, int64_t dim) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + + auto shapedType = value.getType().dyn_cast(); + if (shapedType) { + // Static dimension: return constant directly. + if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) + return builder.getAffineConstantExpr(shapedType.getDimSize(dim)); + } else { + // Constant index value: return directly. + if (auto constInt = getConstantIntValue(value)) + return builder.getAffineConstantExpr(*constInt); + } + + // Dynamic value: add to constraint set. + ValueDim valueDim = std::make_pair(value, dim); + if (valueDimToPosition.find(valueDim) == valueDimToPosition.end()) + (void)insert(valueDim); + int64_t pos = getPos(value, dim); + return pos < cstr.getNumDimVars() + ? 
builder.getAffineDimExpr(pos) + : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { + if (Value value = ofr.dyn_cast()) + return getExpr(value, /*dim=*/kIndexValue); + auto constInt = getConstantIntValue(ofr); + assert(constInt.has_value() && "expected Integer constant"); + return builder.getAffineConstantExpr(*constInt); +} + +AffineExpr ValueBoundsConstraintSet::getExpr(int64_t val) { + return builder.getAffineConstantExpr(val); +} + +int64_t ValueBoundsConstraintSet::insert(ValueDim valueDim, bool isSymbol) { + assert((valueDimToPosition.find(valueDim) == valueDimToPosition.end()) && + "already mapped"); + int64_t pos = isSymbol ? cstr.appendSymbolVar() : cstr.appendDimVar(); + positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); + // Update reverse mapping. + for (int64_t i = pos; i < positionToValueDim.size(); ++i) + valueDimToPosition[positionToValueDim[i]] = i; + + worklist.insert(pos); + return pos; +} + +int64_t ValueBoundsConstraintSet::getPos(Value value, int64_t dim) const { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + + auto it = valueDimToPosition.find(std::make_pair(value, dim)); + assert(it != valueDimToPosition.end() && "expected mapped entry"); + return it->second; +} + +static Operation *getOwnerOfValue(Value value) { + if (auto bbArg = value.dyn_cast()) + return bbArg.getOwner()->getParentOp(); + return value.getDefiningOp(); +} + +void ValueBoundsConstraintSet::processWorklist( + function_ref stopCondition) { + while (!worklist.empty()) { + int64_t pos = worklist.pop_back_val(); + ValueDim valueDim = positionToValueDim[pos]; + Value value = valueDim.first; + int64_t dim = valueDim.second; + + // Check for static dim size. 
+ if (dim != kIndexValue) { + auto shapedType = value.getType().cast(); + if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { + addBound(IntegerPolyhedron::BoundType::EQ, value, dim, + builder.getAffineConstantExpr(shapedType.getDimSize(dim))); + continue; + } + } + + // Do not process any further if the stop condition is met. + if (stopCondition(value)) + continue; + + // Query `ValueBoundsOpInterface` for constraints. New items may be added to + // the worklist. + auto valueBoundsOp = + dyn_cast(getOwnerOfValue(value)); + if (!valueBoundsOp) + continue; + if (dim == kIndexValue) { + valueBoundsOp.populateBoundsForIndexValue(value, *this); + } else { + valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this); + } + } +} + +void ValueBoundsConstraintSet::projectOut(int64_t pos) { + cstr.projectOut(pos); + bool erased = valueDimToPosition.erase(positionToValueDim[pos]); + assert(erased && "inconsistent reverse mapping"); + positionToValueDim.erase(positionToValueDim.begin() + pos); + // Update reverse mapping. + for (int64_t i = pos; i < positionToValueDim.size(); ++i) + valueDimToPosition[positionToValueDim[i]] = i; +} + +FailureOr ValueBoundsConstraintSet::reifyBound( + OpBuilder &b, Location loc, presburger::IntegerPolyhedron::BoundType type, + Value value, int64_t dim) { + auto stopCondition = [&](Value v) { + // Reify in terms of SSA values that are different from `value`. + return v != value; + }; + return ValueBoundsConstraintSet::reifyBound(b, loc, type, value, dim, + stopCondition); +} + +FailureOr ValueBoundsConstraintSet::reifyBound( + OpBuilder &b, Location loc, presburger::IntegerPolyhedron::BoundType type, + Value value, int64_t dim, function_ref stopCondition) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + + // Only EQ bounds are supported at the moment. 
+ assert(type == presburger::IntegerPolyhedron::BoundType::EQ && + "unsupported bound type"); + + // Process the backward slice of `value` (i.e., reverse use-def chain) until + // `stopCondition` is met. + ValueBoundsConstraintSet cstr(std::make_pair(value, dim)); + int64_t pos = cstr.getPos(value, dim); + cstr.processWorklist(stopCondition); + + // Project out all positions (apart from `pos`) that do not match the stop + // condition. + int64_t nextPos = 0; + while (nextPos < cstr.positionToValueDim.size()) { + if (nextPos == pos) { + ++nextPos; + continue; + } + + if (!stopCondition(cstr.positionToValueDim[nextPos].first)) { + cstr.projectOut(nextPos); + // The column was projected out so another column is now at that position. + // Do not increase the counter. + } else { + ++nextPos; + } + } + + // Compute lower and upper bounds for `value`. + SmallVector lb(1), ub(1); + cstr.cstr.getSliceBounds(pos, 1, b.getContext(), &lb, &ub, + /*getClosedUB=*/true); + // Note: There are TODOs in the implementation of `getSliceBounds`. In such a + // case, no lower/upper bound can be computed at the moment. + if (lb.empty() || !lb[0] || ub.empty() || !ub[0] || + lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1) + return failure(); + + // Look for same lower and upper bound: EQ bound. + if (ub[0] != lb[0]) + return failure(); + + // Gather all SSA values that are used in the computed bound. + SmallVector operands; + assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && + "inconsistent mapping state"); + for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) { + // Skip `value`. + if (i == pos) + continue; + // Check if the position `i` is used in the generated bound. If so, it must + // be included in the generated affine.apply op. 
+ bool used = false; + if (i < cstr.cstr.getNumDimVars()) { + if (lb[0].isFunctionOfDim(i)) + used = true; + } else { + if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) + used = true; + } + + if (!used) { + // Not used: Put an empty Value (will canonicalize away). + operands.push_back(Value()); + continue; + } + + ValueBoundsConstraintSet::ValueDim valueDim = cstr.positionToValueDim[i]; + Value value = valueDim.first; + int64_t dim = valueDim.second; + if (dim == ValueBoundsConstraintSet::kIndexValue) { + // An index-type value is used: can be used directly in the affine.apply + // op. + assert(value.getType().isIndex() && "expected index type"); + operands.push_back(value); + continue; + } + + assert(value.getType().cast().isDynamicDim(dim) && + "expected dynamic dim"); + if (value.getType().isa()) { + // A tensor dimension is used: generate a tensor.dim. + operands.push_back(b.create(loc, value, dim)); + } else if (value.getType().isa()) { + // A memref dimension is used: generate a memref.dim. + operands.push_back(b.create(loc, value, dim)); + } else { + llvm_unreachable("cannot generate DimOp for unsupported shaped type"); + } + } + + mlir::canonicalizeMapAndOperands(&lb[0], &operands); + // Check for special cases where no affine.apply op is needed. + if (lb[0].isSingleConstant()) { + // Bound is a constant: return an IntegerAttr. + return static_cast( + b.getIndexAttr(lb[0].getSingleConstantResult())); + } + // No affine.apply op is needed if the bound is a single SSA value. + if (auto expr = lb[0].getResult(0).dyn_cast()) + return static_cast(operands[expr.getPosition()]); + if (auto expr = lb[0].getResult(0).dyn_cast()) + return static_cast( + operands[expr.getPosition() + lb[0].getNumDims()]); + // General case: build affine.apply op.
+ return static_cast( + b.create(loc, lb[0], operands).getResult()); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,137 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h" + +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +using namespace mlir; +using namespace mlir::linalg; +using presburger::IntegerPolyhedron; + +namespace mlir { +namespace tensor { +namespace { + +struct CastOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto castOp = cast(op); + assert(value == castOp.getResult() && "invalid value"); + + if (castOp.getResult().getType().isa() && + castOp.getSource().getType().isa()) { + cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, dim, + cstr.getExpr(castOp.getSource(), dim)); + } + } +}; + +struct DimOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto dimOp = cast(op); + assert(value == dimOp.getResult() && "invalid value"); + + auto constIndex = dimOp.getConstantIndex(); + if (!constIndex.has_value()) + return; + 
cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, + cstr.getExpr(dimOp.getSource(), *constIndex)); + } +}; + +struct EmptyOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto emptyOp = cast(op); + assert(value == emptyOp.getResult() && "invalid value"); + + cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, dim, + cstr.getExpr(emptyOp.getMixedSizes()[dim])); + } +}; + +struct ExtractSliceOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto extractSliceOp = cast(op); + assert(value == extractSliceOp.getResult() && "invalid value"); + + llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims(); + int64_t ctr = -1; + for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) { + // Skip over rank-reduced dimensions. 
+ if (!dropped.test(i)) + ++ctr; + if (ctr == dim) { + cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, dim, + extractSliceOp.getMixedSizes()[i]); + return; + } + } + llvm_unreachable("could not find non-rank-reduced dim"); + } +}; + +struct PadOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto padOp = cast(op); + assert(value == padOp.getResult() && "invalid value"); + + AffineExpr expr = cstr.getExpr(padOp.getSource(), dim) + + cstr.getExpr(padOp.getMixedLowPad()[dim]) + + cstr.getExpr(padOp.getMixedHighPad()[dim]); + cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, dim, expr); + } +}; + +struct RankOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto rankOp = cast(op); + assert(value == rankOp.getResult() && "invalid value"); + + auto tensorType = rankOp.getTensor().getType().dyn_cast(); + if (!tensorType) + return; + cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, + cstr.getExpr(tensorType.getRank())); + } +}; + +} // namespace +} // namespace tensor +} // namespace mlir + +void mlir::linalg::registerValueBoundsOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + tensor::CastOp::attachInterface(*ctx); + tensor::DimOp::attachInterface(*ctx); + tensor::EmptyOp::attachInterface(*ctx); + tensor::ExtractSliceOp::attachInterface( + *ctx); + tensor::InsertOp::attachInterface< + DstValueBoundsOpInterfaceExternalModel>(*ctx); + tensor::InsertSliceOp::attachInterface< + DstValueBoundsOpInterfaceExternalModel>(*ctx); + tensor::PadOp::attachInterface(*ctx); + tensor::RankOp::attachInterface(*ctx); + }); +} diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir 
b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-refiy-shape-dims \ +// RUN: -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s + +func.func @unknown_op() -> index { + %0 = "test.foo"() : () -> (tensor) + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cast( +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[c10]] +func.func @cast(%t: tensor<10xf32>) -> index { + %0 = tensor.cast %t : tensor<10xf32> to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +func.func @cast_unranked(%t: tensor<*xf32>) -> index { + %0 = tensor.cast %t : tensor<*xf32> to tensor + // expected-error @below{{could not reify bound}} + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @dim( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: return %[[dim]] +func.func @dim(%t: tensor) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %t, %c0 : tensor + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @empty( +// CHECK-SAME: %[[sz:.*]]: index +// CHECK: %[[c6:.*]] = arith.constant 6 : index +// CHECK: return %[[c6]], %[[sz]] +func.func @empty(%sz: index) -> (index, index) { + %0 = tensor.empty(%sz) : tensor<6x?xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<6x?xf32>) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<6x?xf32>) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_dynamic( +// CHECK-SAME: 
%[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_dynamic(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[2][%sz][1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_static( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: return %[[c5]] +func.func @extract_slice_static(%t: tensor) -> index { + %0 = tensor.extract_slice %t[2][5][1] : tensor to tensor<5xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_slice_rank_reduce( +// CHECK-SAME: %[[t:.*]]: tensor, %[[sz:.*]]: index +// CHECK: return %[[sz]] +func.func @extract_slice_rank_reduce(%t: tensor, %sz: index) -> index { + %0 = tensor.extract_slice %t[0, 2][1, %sz][1, 1] : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @insert( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: return %[[dim]] +func.func @insert(%t: tensor, %f: f32, %pos: index) -> index { + %0 = tensor.insert %f into %t[%pos] : tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 12)> +// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK-LABEL: func @pad( +// CHECK-SAME: %[[t:.*]]: tensor, %[[a:.*]]: index, %[[b:.*]]: index +// CHECK: %[[bound1:.*]] = affine.apply #[[$map]]()[%[[b]]] +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: %[[bound0:.*]] = affine.apply #[[$map1]]()[%[[dim0]], %[[a]]] +// CHECK: return %[[bound0]], %[[bound1]] +func.func @pad(%t: tensor, %a: index, %b: index) 
-> (index, index) { + %pad = arith.constant 0.0 : f32 + %0 = tensor.pad %t low[%a, 5] high[%a, %b] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %pad : f32 + } : tensor to tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (tensor) -> (index) + return %1, %2 : index, index +} + +// ----- + +// CHECK-LABEL: func @rank( +// CHECK-SAME: %[[t:.*]]: tensor<5xf32> +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: return %[[c1]] +func.func @rank(%t: tensor<5xf32>) -> index { + %0 = tensor.rank %t : tensor<5xf32> + %1 = "test.reify_bound"(%0) : (index) -> (index) + return %1 : index +} diff --git a/mlir/test/Dialect/Linalg/value-bounds-reification.mlir b/mlir/test/Dialect/Linalg/value-bounds-reification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/value-bounds-reification.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-refiy-shape-dims reify-to-func-args" \ +// RUN: -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: func @reify_through_chain( +// CHECK-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: return %[[sz0]], %[[c10]], %[[sz2]] +func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %0 = tensor.empty(%sz0, %sz2) : tensor + %1 = tensor.cast %0 : tensor to tensor + %pos = arith.constant 0 : index + %f = arith.constant 0.0 : f32 + %2 = tensor.insert %f into %1[%pos, %pos, %pos] : tensor + %3 = tensor.dim %2, %c2 : tensor + + %4 = "test.reify_bound"(%2) {dim = 0} : (tensor) -> (index) + %5 = "test.reify_bound"(%2) {dim = 1} : (tensor) -> (index) + %6 = "test.reify_bound"(%3) : (index) -> (index) + + return %4, %5, %6 : index, index, index +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- 
a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/PassManager.h" @@ -127,6 +128,12 @@ *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; + Option testReifyBounds{*this, "test-refiy-shape-dims", + llvm::cl::desc("Test value bounds reification"), + llvm::cl::init(false)}; + Option reifyToFuncArgs{ + *this, "reify-to-func-args", + llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)}; }; } // namespace @@ -217,6 +224,65 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static LogicalResult reifyValueBounds(func::FuncOp funcOp, + bool reifyToFuncArgs) { + IRRewriter rewriter(funcOp.getContext()); + WalkResult result = funcOp.walk([&](Operation *op) { + if (op->getName().getStringRef() == "test.reify_bound") { + if (op->getNumOperands() != 1 || op->getNumResults() != 1 || + !op->getResultTypes()[0].isIndex()) { + op->emitOpError("invalid op"); + return WalkResult::skip(); + } + Value value = op->getOperand(0); + if (value.getType().isa() != + !op->hasAttrOfType("dim")) { + // Op should have "dim" attribute if and only if the operand is not an + // index-typed value. + op->emitOpError("invalid op"); + return WalkResult::skip(); + } + int64_t dim = value.getType().isIndex() + ?
ValueBoundsConstraintSet::kIndexValue + : op->getAttrOfType("dim").getInt(); + + rewriter.setInsertionPointAfterValue(value); + FailureOr reified; + if (!reifyToFuncArgs) { + reified = ValueBoundsConstraintSet::reifyBound( + rewriter, op->getLoc(), + presburger::IntegerPolyhedron::BoundType::EQ, value, dim); + } else { + auto stopCondition = [](Value v) { + auto bbArg = v.dyn_cast(); + if (!bbArg) + return false; + return isa( + bbArg.getParentBlock()->getParentOp()); + }; + reified = ValueBoundsConstraintSet::reifyBound( + rewriter, op->getLoc(), + presburger::IntegerPolyhedron::BoundType::EQ, value, dim, + stopCondition); + } + if (failed(reified)) { + op->emitOpError("could not reify bound"); + return WalkResult::interrupt(); + } + if (auto val = reified->dyn_cast()) { + rewriter.replaceOp(op, val); + return WalkResult::skip(); + } + Value constOp = rewriter.create( + op->getLoc(), reified->get().cast().getInt()); + rewriter.replaceOp(op, constOp); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + /// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -243,6 +309,9 @@ return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); + if (testReifyBounds) + if (failed(reifyValueBounds(getOperation(), reifyToFuncArgs))) + return signalPassFailure(); } namespace mlir { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8503,6 +8503,7 @@ ":TilingInterface", ":TransformUtils", ":Transforms", + ":ValueBoundsOpInterfaceIncGen", ":VectorDialect", ":VectorToSCF", ":VectorTransforms", @@ -8512,6 +8513,37 @@ ], ) +td_library( + name = "ValueBoundsOpInterfaceTdFiles", + srcs = [ + "include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.td", + ], + includes = ["include"], + deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "ValueBoundsOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.td", + deps = [ + ":ValueBoundsOpInterfaceTdFiles", + ], +) + cc_library( name = "TilingInterface", srcs = ["lib/Interfaces/TilingInterface.cpp"],