diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -172,6 +172,10 @@
     RankedTensorType getResultType() { return result().getType().cast<RankedTensorType>(); }
+    SmallVector<AffineMap, 4> getReassociationMaps() {
+      return llvm::to_vector<4>(llvm::map_range(reassociation(),
+          [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
+    }
   }];
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,8 +18,10 @@
 namespace mlir {
 class FuncOp;
+class MLIRContext;
 class ModuleOp;
 template <typename T> class OperationPass;
+class OwningRewritePatternList;
 class Pass;
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
@@ -48,6 +50,10 @@
 /// Placeholder for now, this is NYI.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 
+/// Patterns for fusing linalg operations on tensors.
+void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
+                                           OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_PASSES_H_
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -577,6 +577,166 @@
 };
 } // namespace
 
+/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
+/// provided, given the shape of the source tensor that corresponds to the
+/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
+/// are logically ordered "row-major".
+static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
+                                        ArrayRef<int64_t> sourceShape,
+                                        ArrayRef<AffineMap> reassociationMaps) {
+  SmallVector<AffineExpr, 4> resultExprs;
+  resultExprs.reserve(reassociationMaps.size());
+  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
+  MLIRContext *context = sourceMap.getContext();
+
+  // Generic function to linearize the collapsed dimensions of the source map.
+  auto linearizeFn = [&context](ArrayRef<AffineExpr> exprs,
+                                ArrayRef<int64_t> shape) -> AffineExpr {
+    assert(!exprs.empty() && exprs.size() == shape.size());
+    AffineExpr resultExpr = nullptr;
+    int64_t stride = 1;
+    for (auto it : llvm::zip(reverse(exprs), reverse(shape))) {
+      resultExpr =
+          (resultExpr
+               ? std::get<0>(it) * getAffineConstantExpr(stride, context) +
+                     resultExpr
+               : std::get<0>(it));
+      stride *= std::get<1>(it);
+    }
+    return resultExpr;
+  };
+
+  // Compute the result exprs based on the reassociation maps.
+  for (AffineMap map : reassociationMaps) {
+    ArrayRef<AffineExpr> collapsedDims = map.getResults();
+    // Assume that they are in-order and contiguous (already checked in
+    // verifier).
+    assert(!collapsedDims.empty());
+    unsigned startDim =
+        collapsedDims.front().cast<AffineDimExpr>().getPosition();
+    unsigned endDim = collapsedDims.back().cast<AffineDimExpr>().getPosition();
+    AffineExpr linearizedExpr =
+        linearizeFn(sourceExprs.slice(startDim, endDim - startDim + 1),
+                    sourceShape.slice(startDim, endDim - startDim + 1));
+    resultExprs.push_back(linearizedExpr);
+  }
+  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
+                        resultExprs, context);
+}
+
+static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
+                                     AffineMap useIndexMap, bool asProducer) {
+  RankedTensorType returnType = reshapeOp.getResultType();
+  RankedTensorType operandType = reshapeOp.getSrcType();
+  return
+      // Fuse only if the operands are statically shaped.
+      returnType.hasStaticShape() && operandType.hasStaticShape() &&
+      // Reshape is fusible with its consumer (i.e. reshape as a producer)
+      // when its operand is of lesser rank than the result. Fusing when the
+      // operand has higher rank would require mods and divs in the indexing
+      // maps of the fused op, which would make it non-invertible. Similarly,
+      // reshape is fused with its producer (i.e. reshape as a consumer) only
+      // if the return type has lesser rank.
+      (asProducer ? returnType.getRank() >= operandType.getRank()
+                  : operandType.getRank() >= returnType.getRank()) &&
+      // Currently only implement the case where the consumer uses an identity
+      // map for the indexing_map.
+      useIndexMap.isIdentity();
+}
+
+namespace {
+template <typename LinalgOpTy>
+struct FuseTensorReshapeOpAsProducer
+    : public FuseTensorOpsImpl<FuseTensorReshapeOpAsProducer<LinalgOpTy>,
+                               TensorReshapeOp, LinalgOpTy> {
+  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+                        unsigned consumerIdx) {
+    return isTensorReshapeOpFusible(
+        producer, consumer.getInputIndexingMap(consumerIdx), true);
+  }
+
+  static Operation *createFusedOp(PatternRewriter &rewriter,
+                                  ArrayRef<Value> fusedOperands,
+                                  TensorReshapeOp producer, LinalgOpTy consumer,
+                                  unsigned consumerIdx,
+                                  OperationFolder *folder = nullptr) {
+    // Compute indexing_maps for the fused operation. The indexing_maps for the
+    // operands of the consumer that aren't fused are the same as before.
+    auto consumerIndexMaps = consumer.indexing_maps();
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    fusedIndexMaps.reserve(consumerIndexMaps.size());
+    fusedIndexMaps.append(consumerIndexMaps.begin(),
+                          std::next(consumerIndexMaps.begin(), consumerIdx));
+
+    // Compute the indexing map to use for the operand of the producer.
+    AffineMap producerOperandIndexingMap = linearizeCollapsedDims(
+        consumer.getInputIndexingMap(consumerIdx),
+        producer.getResultType().getShape(), producer.getReassociationMaps());
+    fusedIndexMaps.push_back(AffineMapAttr::get(producerOperandIndexingMap));
+
+    // The rest of the indexing_maps need to be obtained from the remaining
+    // unfused consumer operands.
+    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
+                          consumerIndexMaps.end());
+
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(fusedOperands.size()),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+                               fusedRegion.begin());
+    return fusedOp;
+  }
+};
+
+template <typename LinalgOpTy>
+struct FuseTensorReshapeOpAsConsumer
+    : public FuseTensorOpsImpl<FuseTensorReshapeOpAsConsumer<LinalgOpTy>,
+                               LinalgOpTy, TensorReshapeOp> {
+  static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+                        unsigned consumerIdx) {
+    return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+                                    false);
+  }
+
+  static Operation *createFusedOp(PatternRewriter &rewriter,
+                                  ArrayRef<Value> fusedOperands,
+                                  LinalgOpTy producer, TensorReshapeOp consumer,
+                                  unsigned consumerIdx,
+                                  OperationFolder *folder = nullptr) {
+    // The indexing_maps for the operands that come from the original producer
+    // are the same as before.
+    auto producerIndexMaps = producer.indexing_maps();
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    fusedIndexMaps.reserve(producerIndexMaps.size());
+    fusedIndexMaps.append(producerIndexMaps.begin(),
+                          std::prev(producerIndexMaps.end()));
+
+    // Compute the indexing map to use for the result of the producer.
+    AffineMap producerResultIndexingMap = linearizeCollapsedDims(
+        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
+        consumer.getReassociationMaps());
+    fusedIndexMaps.push_back(AffineMapAttr::get(producerResultIndexingMap));
+
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultType(), fusedOperands,
+        rewriter.getI64IntegerAttr(fusedOperands.size()),
+        rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(fusedIndexMaps),
+        producer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+                               fusedRegion.begin());
+    return fusedOp;
+  }
+};
+} // namespace
+
 template <>
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                        GenericOp consumer,
@@ -591,6 +751,24 @@
     if (genericOp.hasTensorSemantics())
       return FuseGenericOpsOnTensors::fuse(rewriter, genericOp, consumer,
                                            consumerIdx);
+  } else if (auto reshapeOp = dyn_cast<TensorReshapeOp>(producer)) {
+    return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
+        rewriter, reshapeOp, consumer, consumerIdx);
+  }
+  return nullptr;
+}
+
+template <>
+Operation *mlir::linalg::fuseTensorOps(
+    PatternRewriter &rewriter, TensorReshapeOp consumer, unsigned consumerIdx,
+    OperationFolder *folder) {
+  Operation *producer = consumer.getOperand().getDefiningOp();
+  if (!producer || producer->getNumResults() != 1)
+    return nullptr;
+  if (auto genericOp = dyn_cast<GenericOp>(producer)) {
+    if (genericOp.hasTensorSemantics())
+      return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
+          rewriter, genericOp, consumer, consumerIdx);
   }
   return nullptr;
 }
@@ -626,7 +804,7 @@
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    patterns.insert<FuseTensorOps<GenericOp>>(op->getContext());
+    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
   };
 };
@@ -636,6 +814,12 @@
 };
 } // namespace
 
+void mlir::populateLinalgTensorOpsFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
+      context);
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
   return std::make_unique<LinalgFusionPass>();
 }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -129,3 +129,68 @@
   return %1 : tensor
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<2x12x5xf32>,
+                                         %arg1 : tensor<2x3x4x5xf32>) ->
+                                         tensor<2x3x4x5xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %1 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    %0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+    }: tensor<2x3x4x5xf32>, tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
+  return %1 : tensor<2x3x4x5xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<2x3x4x5xf32>,
+                                         %arg1 : tensor<2x3x4x5xf32>) ->
+                                         tensor<2x60xf32>
+{
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+    }: tensor<2x3x4x5xf32>, tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<2x3x4x5xf32> into tensor<2x60xf32>
+  return %1 : tensor<2x60xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.generic
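
For reference, a minimal sketch of what the consumer-fusion test above is expected to produce: the linalg.tensor_reshape is absorbed into the linalg.generic, whose result is written directly as tensor<2x60xf32> through the linearized output map. The coefficients of that map are the row-major strides of the collapsed 3x4x5 dimensions (4 * 5 = 20, 5, and 1), matching #[[MAP1]] in the CHECK lines. The SSA value names and exact formatting below are illustrative, not taken from actual pass output.

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
// Roughly the fused op: both inputs keep identity maps, while the output uses
// the linearized map so the 4-d iteration space writes the collapsed 2x60
// result directly.
%fused = linalg.generic
  {args_in = 2 : i64, args_out = 1 : i64,
   indexing_maps = [#map0, #map0, #map1],
   iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  %arg0, %arg1 {
  ^bb0(%a: f32, %b: f32):
    %m = mulf %a, %b : f32
    linalg.yield %m : f32
  }: tensor<2x3x4x5xf32>, tensor<2x3x4x5xf32> -> tensor<2x60xf32>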