diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -172,6 +172,10 @@ RankedTensorType getResultType() { return result().getType().cast(); } + SmallVector getReassociationMaps() { + return llvm::to_vector<4>(llvm::map_range(reassociation(), + [](Attribute a) { return a.cast().getValue(); })); + } }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -18,8 +18,10 @@ namespace mlir { class FuncOp; +class MLIRContext; class ModuleOp; template class OperationPass; +class OwningRewritePatternList; class Pass; std::unique_ptr> createLinalgFusionPass(); @@ -48,6 +50,10 @@ /// Placeholder for now, this is NYI. std::unique_ptr> createConvertLinalgToAffineLoopsPass(); +/// Patterns for fusing linalg operations on tensors. +void populateLinalgTensorOpsFusionPatterns(MLIRContext *context, + OwningRewritePatternList &patterns); + } // namespace mlir #endif // MLIR_DIALECT_LINALG_PASSES_H_ diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -672,9 +672,18 @@ /// varying stride is always `1`. /// /// Examples: -/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`. -/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`. -/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`. +/// - memref<3x4x5xf32> has canonical stride expression +/// `20*exprs[0] + 5*exprs[1] + exprs[2]`. +/// - memref<3x?x5xf32> has canonical stride expression +/// `s0*exprs[0] + 5*exprs[1] + exprs[2]`. +/// - memref<3x4x?xf32> has canonical stride expression +/// `s1*exprs[0] + s0*exprs[1] + exprs[2]`. 
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, + ArrayRef exprs, + MLIRContext *context); + +/// Return the result of makeCanonicalStridedLayoutExpr for the common case +/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}. AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, MLIRContext *context); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -530,7 +530,7 @@ unsigned currentDim = 0; for (AffineMap m : reassociation) { unsigned dim = m.getNumResults(); - auto band = shape.drop_front(currentDim).take_front(dim); + auto band = shape.slice(currentDim, dim); int64_t size = 1; if (llvm::is_contained(band, ShapedType::kDynamicSize)) size = ShapedType::kDynamicSize; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -577,6 +577,202 @@ }; } // namespace +/// Linearize the expressions in `sourceMap` based on the `reassociationMaps` +/// provided, given the shape of the source tensor that corresponds to the +/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions +/// are "row-major" ordered logically. +/// +/// For example: +/// +/// %0 = op ... : tensor +/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` + +/// and reshape: +/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, +/// affine_map<(i, j, k, l) -> (j, k, l)>] : +/// tensor into tensor + +/// would be rewritten into: +/// %0 = op ... 
: tensor +/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 +/// + d3)>` +static AffineMap linearizeCollapsedDims(AffineMap sourceMap, + ArrayRef sourceShape, + ArrayRef reassociationMaps) { + SmallVector resultExprs; + resultExprs.reserve(reassociationMaps.size()); + ArrayRef sourceExprs = sourceMap.getResults(); + MLIRContext *context = sourceMap.getContext(); + + // Compute the result exprs based on the reassociation maps. + for (AffineMap map : reassociationMaps) { + ArrayRef collapsedDims = map.getResults(); + // Assume that they are in-order and contiguous (already checked in + // verifier). + assert(!collapsedDims.empty()); + unsigned startDim = + collapsedDims.front().cast().getPosition(); + AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr( + sourceShape.slice(startDim, collapsedDims.size()), + sourceExprs.slice(startDim, collapsedDims.size()), context); + resultExprs.push_back(linearizedExpr); + } + return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(), + resultExprs, context); +} + +/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is +/// true) or its producer (if `asProducer` is false) given the indexing map at +/// its use. +static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp, + AffineMap useIndexMap, bool asProducer) { + RankedTensorType returnType = reshapeOp.getResultType(); + RankedTensorType operandType = reshapeOp.getSrcType(); + // Reshape is fusible with its consumer (i.e. reshape as a producer) when its + // operand is of lesser rank than the result. Fusing when operand has higher + // rank will require use of mods and divs in the indexing maps of the fused op + // which would make it non-invertible. Similarly reshape is fused with its + // producer (i.e. reshape as consumer) only if the return type has lesser + // rank. 
+ if ((asProducer && returnType.getRank() < operandType.getRank()) || + (!asProducer && operandType.getRank() < returnType.getRank())) + return false; + // For reshape to be fused, the collapsed (expanded) dimensions of the operand + // (result) must be statically shaped. + // TODO: In the future this restriction may justify extending the + // linalg.generic to semi-affine maps. + // TODO: Alternatively, fusing across a reshape and pushing the reshape + // towards the boundary of the function could help too. + ArrayRef srcShape = + (asProducer ? returnType.getShape() : operandType.getShape()); + // Hold by value: an ArrayRef to the returned temporary would dangle. + auto reassociationMaps = reshapeOp.getReassociationMaps(); + unsigned dim = 0; + for (AffineMap map : reassociationMaps) { + ArrayRef collapsedDims = map.getResults(); + // If the linearized expr has symbols, then disable fusion since indexing + // maps cannot have symbols right now. + AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr( + srcShape.slice(dim, collapsedDims.size()), + useIndexMap.getResults().slice(dim, collapsedDims.size()), + reshapeOp.getContext()); + bool hasSymbol = false; + linearizedExpr.walk([&hasSymbol](AffineExpr d) { + if (d.isa()) + hasSymbol = true; + }); + if (hasSymbol) + return false; + dim += collapsedDims.size(); + } + // Currently only implement the case where the indexing map at the use + // (`useIndexMap`) is the identity. + return useIndexMap.isIdentity(); +} + +namespace { +template +struct FuseTensorReshapeOpAsProducer + : public FuseTensorOpsImpl, + TensorReshapeOp, LinalgOpTy> { + static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer, + unsigned consumerIdx) { + return isTensorReshapeOpFusible( + producer, consumer.getInputIndexingMap(consumerIdx), true); + } + + static Operation *createFusedOp(PatternRewriter &rewriter, + ArrayRef fusedOperands, + TensorReshapeOp producer, LinalgOpTy consumer, + unsigned consumerIdx, + OperationFolder *folder = nullptr) { + // Compute indexing_maps for the fused operation. 
The indexing_maps for the + // operands of the consumer that aren't fused are the same. + SmallVector fusedIndexMaps = + llvm::to_vector<4>(llvm::map_range( + consumer.indexing_maps(), [](Attribute attr) -> AffineMap { + return attr.cast().getValue(); + })); + // Compute the indexing map to use for the operand of the producer. + fusedIndexMaps[consumerIdx] = linearizeCollapsedDims( + fusedIndexMaps[consumerIdx], producer.getResultType().getShape(), + producer.getReassociationMaps()); + + // Further check that the resulting index maps can be fused and + // inverted. Without this the resultant op is not legal. + if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) + return nullptr; + + SmallVector indexMapAttrs = llvm::to_vector<4>( + llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })); + auto fusedOp = rewriter.create( + rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, + rewriter.getI64IntegerAttr(fusedOperands.size()), + rewriter.getI64IntegerAttr(consumer.getNumResults()), + rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr); + auto &fusedRegion = fusedOp.region(); + rewriter.cloneRegionBefore(consumer.region(), fusedRegion, + fusedRegion.begin()); + return fusedOp; + } +}; + +template +struct FuseTensorReshapeOpAsConsumer + : public FuseTensorOpsImpl, + LinalgOpTy, TensorReshapeOp> { + static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer, + unsigned consumerIdx) { + return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0), + false); + } + + static Operation *createFusedOp(PatternRewriter &rewriter, + ArrayRef fusedOperands, + LinalgOpTy producer, TensorReshapeOp consumer, + unsigned consumerIdx, + OperationFolder *folder = nullptr) { + // The indexing_maps for the operands of the fused operation are same as + // those for the operands of the producer. 
+ SmallVector fusedIndexMaps = + llvm::to_vector<4>(llvm::map_range( + producer.indexing_maps(), [](Attribute attr) -> AffineMap { + return attr.cast().getValue(); + })); + // Compute the indexing map to use for the result of the fused operation. + fusedIndexMaps.back() = linearizeCollapsedDims( + producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(), + consumer.getReassociationMaps()); + + // Further check that the resulting index maps can be fused and + // inverted. Without this the resultant op is not legal. + if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) + return nullptr; + + SmallVector indexMapAttrs = llvm::to_vector<4>( + llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })); + + auto fusedOp = rewriter.create( + rewriter.getUnknownLoc(), consumer.getResultType(), fusedOperands, + rewriter.getI64IntegerAttr(fusedOperands.size()), + rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs), + producer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr); + auto &fusedRegion = fusedOp.region(); + rewriter.cloneRegionBefore(producer.region(), fusedRegion, + fusedRegion.begin()); + return fusedOp; + } +}; +} // namespace + Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, unsigned consumerIdx, @@ -587,6 +783,7 @@ if (!producer || producer->getNumResults() != 1) return nullptr; + // Fuse when consumer is GenericOp. if (GenericOp genericOp = dyn_cast(consumer)) { if (!genericOp.hasTensorSemantics()) return nullptr; @@ -594,7 +791,21 @@ if (genericOpProducer.hasTensorSemantics()) return FuseGenericOpsOnTensors::fuse(rewriter, genericOpProducer, genericOp, consumerIdx); + } else if (auto reshapeOpProducer = dyn_cast(producer)) { + return FuseTensorReshapeOpAsProducer::fuse( + rewriter, reshapeOpProducer, genericOp, consumerIdx); + } + return nullptr; + } + + // Fuse when consumer is a TensorReshapeOp. 
+ if (TensorReshapeOp reshapeOp = dyn_cast(consumer)) { + if (auto genericOpProducer = dyn_cast(producer)) { + if (genericOpProducer.hasTensorSemantics()) + return FuseTensorReshapeOpAsConsumer::fuse( + rewriter, genericOpProducer, reshapeOp, consumerIdx); } + return nullptr; } return nullptr; } @@ -630,7 +841,7 @@ void runOnOperation() override { OwningRewritePatternList patterns; Operation *op = getOperation(); - patterns.insert>(op->getContext()); + populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; }; @@ -640,6 +851,12 @@ }; } // namespace +void mlir::populateLinalgTensorOpsFusionPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert, FuseTensorOps>( + context); +} + std::unique_ptr> mlir::createLinalgFusionPass() { return std::make_unique(); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -728,35 +728,49 @@ } AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, + ArrayRef exprs, MLIRContext *context) { + // llvm::zip below stops at the shorter range, so require one expr per size. + assert(sizes.size() == exprs.size() && "mismatched sizes and exprs"); AffineExpr expr; bool dynamicPoisonBit = false; + unsigned numDims = 0; unsigned nSymbols = 0; + // Compute the number of symbols and dimensions of the passed exprs. 
+ for (AffineExpr e : exprs) { + e.walk([&numDims, &nSymbols](AffineExpr d) { + if (AffineDimExpr dim = d.dyn_cast()) + numDims = std::max(numDims, dim.getPosition() + 1); + else if (AffineSymbolExpr symbol = d.dyn_cast()) + nSymbols = std::max(nSymbols, symbol.getPosition() + 1); + }); + } int64_t runningSize = 1; - unsigned rank = sizes.size(); - for (auto en : llvm::enumerate(llvm::reverse(sizes))) { - auto size = en.value(); - auto position = rank - 1 - en.index(); + for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { + int64_t size = std::get<1>(en); // Degenerate case, no size =-> no stride if (size == 0) continue; - auto d = getAffineDimExpr(position, context); - // Static case: stride = runningSize and runningSize *= size. - if (!dynamicPoisonBit) { - auto cst = getAffineConstantExpr(runningSize, context); - expr = expr ? expr + cst * d : cst * d; - if (size > 0) - runningSize *= size; - else - // From now on bail into dynamic mode. - dynamicPoisonBit = true; - continue; - } - // Dynamic case, new symbol for each new stride. - auto sym = getAffineSymbolExpr(nSymbols++, context); - expr = expr ? expr + d * sym : d * sym; + AffineExpr dimExpr = std::get<0>(en); + AffineExpr stride = dynamicPoisonBit + ? getAffineSymbolExpr(nSymbols++, context) + : getAffineConstantExpr(runningSize, context); + expr = expr ? expr + dimExpr * stride : dimExpr * stride; + if (size > 0) + runningSize *= size; + else + dynamicPoisonBit = true; } - return simplifyAffineExpr(expr, rank, nSymbols); + return simplifyAffineExpr(expr, numDims, nSymbols); +} + +AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, + MLIRContext *context) { + SmallVector exprs; + exprs.reserve(sizes.size()); + for (auto dim : llvm::seq(0, sizes.size())) + exprs.push_back(getAffineDimExpr(dim, context)); + return makeCanonicalStridedLayoutExpr(sizes, exprs, context); } /// Return true if the layout for `t` is compatible with strided semantics. 
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -129,3 +129,95 @@ return %1 : tensor } + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_producer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k)>, + affine_map<(i, j, k, l) -> (l)>] : + tensor into tensor + %1 = linalg.generic + {args_in = 2 : i64, args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + %0, %arg1 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + }: tensor, tensor -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func @generic_op_reshape_producer_fusion +// CHECK: linalg.generic +// CHECK-SAME: args_in = 2 +// CHECK-SAME: args_out = 1 +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]] +// CHECK-NOT: linalg.generic + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_consumer_fusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic + {args_in = 2 : i64, args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + %arg0, %arg1 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + }: tensor, tensor -> tensor + %1 = 
linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// CHECK-LABEL: func @generic_op_reshape_consumer_fusion +// CHECK: linalg.generic +// CHECK-SAME: args_in = 2 +// CHECK-SAME: args_out = 1 +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]]] +// CHECK-NOT: linalg.generic + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @generic_op_reshape_consumer_nofusion(%arg0 : tensor, + %arg1 : tensor) -> + tensor +{ + %0 = linalg.generic + {args_in = 2 : i64, args_out = 1 : i64, + indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + %arg0, %arg1 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %1 = mulf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + }: tensor, tensor -> tensor + %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, + affine_map<(i, j, k, l) -> (j, k, l)>] : + tensor into tensor + return %1 : tensor +} + +// No fusion expected: the producer generic op and the reshape both remain. +// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion +// CHECK: linalg.generic +// CHECK: linalg.tensor_reshape