diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -17,13 +17,9 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" - #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" -using mlir::edsc::intrinsics::AffineIndexedValue; -using mlir::edsc::intrinsics::MemRefIndexedValue; - namespace mlir { class AffineExpr; class AffineForOp; @@ -34,33 +30,32 @@ namespace linalg { class LinalgDependenceGraph; -/// A struct containing the Linalg producer before and after fusion. -/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op -/// before the consumer Linalg op, until enough canonicalizations have applied. -struct FusionInfo { - LinalgOp originalProducer; - LinalgOp fusedProducer; -}; +//===----------------------------------------------------------------------===// +// General utilities +//===----------------------------------------------------------------------===// -/// A struct containing common matchers over linalg op's region. -struct RegionMatcher { - enum class BinaryOpKind { - IAdd, - }; +/// Apply the permutation defined by `permutation` to `inVec`. +/// The element at index `permutation[i]` in `inVec` moves to index `i`. +/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector +/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. +template +void applyPermutationToVector(SmallVector &inVec, + ArrayRef permutation) { + SmallVector auxVec(inVec.size()); + for (unsigned i = 0; i < permutation.size(); ++i) + auxVec[i] = inVec[permutation[i]]; + inVec = auxVec; +} - /// Matches the given linalg op if its body is performing binary operation on - /// int or float scalar values and returns the binary op kind.
- /// - /// The linalg op's region is expected to be - /// ``` - /// { - /// ^bb(%a: , %b: ): - /// %0 = %a, %b: - /// linalg.yield %0: - /// } - /// ``` - static Optional matchAsScalarBinaryOp(GenericOp op); -}; +/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp +/// is a constant then return the smallest such constant as an IntegerAttr. +/// If `size` comes from a ConstantOp, return the constant. +/// Otherwise return nullptr. +IntegerAttr getSmallestBoundingIndex(Value size); + +//===----------------------------------------------------------------------===// +// Iterator type utilities +//===----------------------------------------------------------------------===// /// Checks if an iterator_type attribute is parallel. bool isParallelIteratorType(Attribute attr); @@ -71,6 +66,10 @@ /// Checks if an iterator_type attribute is parallel. bool isWindowIteratorType(Attribute attr); +//===----------------------------------------------------------------------===// +// Fusion utilities +//===----------------------------------------------------------------------===// + /// Checks whether the specific `producer` is the last write to exactly the /// whole `consumedView`. This checks structural dominance, that the dependence /// is a RAW without any interleaved write to any piece of `consumedView`. @@ -84,6 +83,21 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer); +/// Returns one value per `tiledOperands` entry: a subtensor/subview created +/// with `builder` if its indexing map in `linalgOp` uses a tiled loop, else the +/// operand itself. Assumes `linalgOp` is being fused into a loop nest for +/// tiling with the induction variables `ivs` and tile sizes `tileSizes`. +/// `sizeBounds` are the iteration space bounds for *all* loops of `linalgOp`. +/// +/// Note that a constant zero in `tileSizes` means no tiling at that implicit +/// loop.
The number of non-zero values in `tileSizes` should be equal to the +/// number of values in `ivs`. +SmallVector makeTiledShapes(OpBuilder &builder, Location loc, + LinalgOp linalgOp, + ArrayRef tiledOperands, + ValueRange ivs, ValueRange tileSizes, + ArrayRef sizeBounds); + using FusableOpDependencesTy = llvm::MapVector< Operation *, SmallVector>; @@ -91,6 +105,14 @@ findAllFusableDependences(ArrayRef ops, const LinalgDependenceGraph &dependenceGraph); +/// A struct containing the Linalg producer before and after fusion. +/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op +/// before the consumer Linalg op, until enough canonicalizations have applied. +struct FusionInfo { + LinalgOp originalProducer; + LinalgOp fusedProducer; +}; + /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. /// Implements the fusion part of the "tileAndFuse on buffers" transformation @@ -119,24 +141,9 @@ Optional> fuseTensorOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand); -/// Apply the permutation defined by `permutation` to `inVec`. -/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. -/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector -/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. -template -void applyPermutationToVector(SmallVector &inVec, - ArrayRef permutation) { - SmallVector auxVec(inVec.size()); - for (unsigned i = 0; i < permutation.size(); ++i) - auxVec[i] = inVec[permutation[i]]; - inVec = auxVec; -} - -/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp -/// is a constant then return a new value set to the smallest such constant. -/// If `size` comes from a ConstantOp, return the constant. -/// Otherwise return nullptr. 
-IntegerAttr getSmallestBoundingIndex(Value size); +//===----------------------------------------------------------------------===// +// Distribution utilities +//===----------------------------------------------------------------------===// /// Scheme used to distribute loops to processors. enum class DistributionMethod { @@ -206,6 +213,34 @@ SmallVector distributionMethod = {}; }; +//===----------------------------------------------------------------------===// +// Generic op region utilities +//===----------------------------------------------------------------------===// + +/// A struct containing common matchers over linalg op's region. +struct RegionMatcher { + enum class BinaryOpKind { + IAdd, + }; + + /// Matches the given linalg op if its body is performing binary operation on + /// int or float scalar values and returns the binary op kind. + /// + /// The linalg op's region is expected to be + /// ``` + /// { + /// ^bb(%a: , %b: ): + /// %0 = %a, %b: + /// linalg.yield %0: + /// } + /// ``` + static Optional matchAsScalarBinaryOp(GenericOp op); +}; + +//===----------------------------------------------------------------------===// +// Loop nest utilities +//===----------------------------------------------------------------------===// + /// Utility class used to generate nested loops with ranges described by /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn` /// is used to generate the body of the innermost loop. 
It is passed a range @@ -214,7 +249,8 @@ struct GenerateLoopNest { using IndexedValueTy = typename std::conditional::value, - AffineIndexedValue, MemRefIndexedValue>::type; + edsc::intrinsics::AffineIndexedValue, + edsc::intrinsics::MemRefIndexedValue>::type; static void doit(ArrayRef loopRanges, ValueRange iterArgInitValues, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -23,7 +23,6 @@ #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -82,34 +81,6 @@ Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]}); return std::make_tuple(res, loopIndexToRangeIndex); } -namespace { - -// Helper visitor to determine whether an AffineExpr is tiled. -// This is achieved by traversing every AffineDimExpr with position `pos` and -// checking whether the corresponding `tileSizes[pos]` is non-zero. -// This also enforces only positive coefficients occur in multiplications. -// -// Example: -// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] -// -struct TileCheck : public AffineExprVisitor { - TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {} - - void visitDimExpr(AffineDimExpr expr) { - isTiled |= !isZero(tileSizes[expr.getPosition()]); - } - void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { - visit(expr.getLHS()); - visit(expr.getRHS()); - if (expr.getKind() == mlir::AffineExprKind::Mul) - assert(expr.getRHS().cast().getValue() > 0 && - "nonpositive multiplying coefficient"); - } - bool isTiled; - ValueRange tileSizes; -}; - -} // namespace // IndexedGenericOp explicitly uses induction variables in the loop body. 
The // values of the indices that are used in the loop body for any given access of @@ -201,117 +172,6 @@ } } -static bool isTiled(AffineExpr expr, ValueRange tileSizes) { - if (!expr) - return false; - TileCheck t(tileSizes); - t.visit(expr); - return t.isTiled; -} - -// Checks whether the `map varies with respect to a non-zero `tileSize`. -static bool isTiled(AffineMap map, ValueRange tileSizes) { - if (!map) - return false; - for (unsigned r = 0; r < map.getNumResults(); ++r) - if (isTiled(map.getResult(r), tileSizes)) - return true; - return false; -} - -static SmallVector -makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp, - ArrayRef tiledOperands, AffineMap map, ValueRange ivs, - ValueRange tileSizes, ValueRange allShapeSizes) { - assert(ivs.size() == static_cast(llvm::count_if( - llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](Value v) { return !isZero(v); })) && - "expected as many ivs as non-zero sizes"); - - using namespace edsc::op; - - auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes); - // Construct (potentially temporary) mins and maxes on which to apply maps - // that define tile subshapes. - SmallVector lbs, subShapeSizes; - for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { - bool isTiled = !isZero(tileSizes[idx]); - lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0)); - // Before composing, we need to make range a closed interval. - Value size = isTiled ? tileSizes[idx] : shapeSizes[idx]; - subShapeSizes.push_back(size - std_constant_index(1)); - } - - SmallVector res; - res.reserve(tiledOperands.size()); - for (auto en : llvm::enumerate(tiledOperands)) { - Value shapedOp = en.value(); - ShapedType shapedType = shapedOp.getType().cast(); - unsigned rank = shapedType.getRank(); - AffineMap map = linalgOp.getIndexingMap(en.index()); - // If the shape is not tiled, we can use it as is. 
- if (!isTiled(map, tileSizes)) { - res.push_back(shapedOp); - continue; - } - - // Construct a new subview / subtensor for the tile. - SmallVector offsets, sizes, strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned r = 0; r < rank; ++r) { - if (!isTiled(map.getSubMap({r}), tileSizes)) { - offsets.push_back(b.getIndexAttr(0)); - sizes.push_back(memref_dim(shapedOp, r).value); - strides.push_back(b.getIndexAttr(1)); - continue; - } - - // Tiling creates a new slice at the proper index, the slice step is 1 - // (i.e. the op does not subsample, stepping occurs in the loop). - auto m = map.getSubMap({r}); - auto offset = applyMapToValues(b, loc, m, lbs).front(); - offsets.push_back(offset); - auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front(); - // Resulting size needs to be made half open interval again. - auto size = closedIntSize + std_constant_index(1); - - // The size of the subview / subtensor should be trimmed to avoid - // out-of-bounds accesses, unless we statically know the subshape size - // divides the shape size evenly. - int64_t shapeSize = shapedType.getDimSize(r); - auto sizeCst = size.getDefiningOp(); - if (ShapedType::isDynamic(shapeSize) || !sizeCst || - (shapeSize % sizeCst.getValue()) != 0) { - // Compute min(size, dim - offset) to avoid out-of-bounds accesses. 
- auto minMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, - {getAffineDimExpr(/*position=*/0, b.getContext()), - getAffineDimExpr(/*position=*/1, b.getContext()) - - getAffineDimExpr(/*position=*/2, b.getContext())}, - b.getContext()); - Value d = memref_dim(shapedOp, r); - SmallVector operands{size, d, offset}; - fullyComposeAffineMapAndOperands(&minMap, &operands); - size = affine_min(b.getIndexType(), minMap, operands); - } - - sizes.push_back(size); - strides.push_back(b.getIndexAttr(1)); - } - - if (shapedType.isa()) - res.push_back( - b.create(loc, shapedOp, offsets, sizes, strides)); - else - res.push_back( - b.create(loc, shapedOp, offsets, sizes, strides)); - } - - return res; -} - template static Optional tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes, @@ -401,9 +261,10 @@ assert(outputBuffers.empty() || iterArgs.empty()); operands.append(outputBuffers.begin(), outputBuffers.end()); operands.append(iterArgs.begin(), iterArgs.end()); - SmallVector tiledOperands = - makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap, - interchangedIvs, tileSizes, allShapeSizes); + auto sizeBounds = + applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes); + SmallVector tiledOperands = makeTiledShapes( + b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds); auto nonShapedOperands = op.getAssumedNonShapedOperands(); tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end()); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -18,8 +18,10 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/SCF/EDSC/Builders.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include 
"mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" @@ -27,9 +29,64 @@ #include "mlir/Transforms/LoopUtils.h" using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::scf; +static bool isZero(Value v) { + if (auto cst = v.getDefiningOp()) + return cst.getValue() == 0; + return false; +} + +namespace { + +// Helper visitor to determine whether an AffineExpr is tiled. +// This is achieved by traversing every AffineDimExpr with position `pos` and +// checking whether the corresponding `tileSizes[pos]` is non-zero. +// This also enforces only positive coefficients occur in multiplications. +// +// Example: +// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] +// +struct TileCheck : public AffineExprVisitor { + TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {} + + void visitDimExpr(AffineDimExpr expr) { + isTiled |= !isZero(tileSizes[expr.getPosition()]); + } + void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { + visit(expr.getLHS()); + visit(expr.getRHS()); + if (expr.getKind() == mlir::AffineExprKind::Mul) + assert(expr.getRHS().cast().getValue() > 0 && + "nonpositive multiplying coefficient"); + } + bool isTiled; + ValueRange tileSizes; +}; + +} // namespace + +static bool isTiled(AffineExpr expr, ValueRange tileSizes) { + if (!expr) + return false; + TileCheck t(tileSizes); + t.visit(expr); + return t.isTiled; +} + +// Checks whether the `map varies with respect to a non-zero `tileSize`. 
+static bool isTiled(AffineMap map, ValueRange tileSizes) { + if (!map) + return false; + for (unsigned r = 0; r < map.getNumResults(); ++r) + if (isTiled(map.getResult(r), tileSizes)) + return true; + return false; +} + Optional RegionMatcher::matchAsScalarBinaryOp(GenericOp op) { auto ®ion = op.region(); @@ -374,5 +431,98 @@ assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); } +SmallVector makeTiledShapes(OpBuilder &builder, Location loc, + LinalgOp linalgOp, + ArrayRef tiledOperands, + ValueRange ivs, ValueRange tileSizes, + ArrayRef sizeBounds) { + assert(ivs.size() == static_cast(llvm::count_if( + llvm::make_range(tileSizes.begin(), tileSizes.end()), + [](Value v) { return !isZero(v); })) && + "expected as many ivs as non-zero sizes"); + + using namespace edsc::op; + + // Construct (potentially temporary) mins and maxes on which to apply maps + // that define tile subshapes. + SmallVector lbs, subShapeSizes; + for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { + bool isTiled = !isZero(tileSizes[idx]); + lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0)); + // Before composing, we need to make range a closed interval. + Value size = isTiled ? tileSizes[idx] : sizeBounds[idx]; + subShapeSizes.push_back(size - std_constant_index(1)); + } + + MLIRContext *context = builder.getContext(); + SmallVector tiledShapes; + tiledShapes.reserve(tiledOperands.size()); + for (auto en : llvm::enumerate(tiledOperands)) { + Value shapedOp = en.value(); + ShapedType shapedType = shapedOp.getType().cast(); + unsigned rank = shapedType.getRank(); + AffineMap map = linalgOp.getIndexingMap(en.index()); + // If the shape is not tiled, we can use it as is. + if (!isTiled(map, tileSizes)) { + tiledShapes.push_back(shapedOp); + continue; + } + + // Construct a new subview / subtensor for the tile. 
+ SmallVector offsets, sizes, strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned r = 0; r < rank; ++r) { + if (!isTiled(map.getSubMap({r}), tileSizes)) { + offsets.push_back(builder.getIndexAttr(0)); + sizes.push_back(memref_dim(shapedOp, r).value); + strides.push_back(builder.getIndexAttr(1)); + continue; + } + + // Tiling creates a new slice at the proper index, the slice step is 1 + // (i.e. the op does not subsample, stepping occurs in the loop). + auto m = map.getSubMap({r}); + auto offset = applyMapToValues(builder, loc, m, lbs).front(); + offsets.push_back(offset); + auto closedIntSize = + applyMapToValues(builder, loc, m, subShapeSizes).front(); + // Resulting size needs to be made half open interval again. + auto size = closedIntSize + std_constant_index(1); + + // The size of the subview / subtensor should be trimmed to avoid + // out-of-bounds accesses, unless we statically know the subshape size + // divides the shape size evenly. + int64_t shapeSize = shapedType.getDimSize(r); + auto sizeCst = size.getDefiningOp(); + if (ShapedType::isDynamic(shapeSize) || !sizeCst || + (shapeSize % sizeCst.getValue()) != 0) { + AffineExpr dim0, dim1, dim2; + bindDims(context, dim0, dim1, dim2); + // Compute min(size, dim - offset) to avoid out-of-bounds accesses. + auto minMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {dim0, dim1 - dim2}, context); + Value d = memref_dim(shapedOp, r); + SmallVector operands{size, d, offset}; + fullyComposeAffineMapAndOperands(&minMap, &operands); + size = affine_min(builder.getIndexType(), minMap, operands); + } + + sizes.push_back(size); + strides.push_back(builder.getIndexAttr(1)); + } + + if (shapedType.isa()) + tiledShapes.push_back(builder.create( + loc, shapedOp, offsets, sizes, strides)); + else + tiledShapes.push_back( + builder.create(loc, shapedOp, offsets, sizes, strides)); + } + + return tiledShapes; +} + } // namespace linalg } // namespace mlir