diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -396,7 +396,8 @@
 } // namespace
 
 template <typename ReshapeOpTy>
-static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
+                                  ArrayRef<Attribute> operands) {
   // Fold producer-consumer reshape ops that where the operand type of the
   // producer is same as the return type of the consumer. This can only be
   // verified if the shapes in question are static.
@@ -406,6 +407,10 @@
       reshapeOp.getResultType().hasStaticShape() &&
       reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
+  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
+    return elements.reshape(
+        reshapeOp.getResult().getType().template cast<ShapedType>());
+  }
   return nullptr;
 }
 
@@ -1175,18 +1180,18 @@
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return foldReshapeOp(*this);
+  return foldReshapeOp(*this, operands);
 }
 OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return {};
 }
-OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
-  return foldReshapeOp(*this);
+OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp(*this, operands);
 }
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -773,6 +773,9 @@
   static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
                        unsigned consumerIdx, PatternRewriter &rewriter,
                        OperationFolder *folder = nullptr) {
+    if (producer.src().getDefiningOp<ConstantOp>())
+      return nullptr;
+
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -826,20 +829,19 @@
 
 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
 struct FuseTensorReshapeOpAsConsumer {
-  static bool isFusible(LinalgOp producer, TensorReshapeOp consumer,
-                        unsigned consumerIdx) {
+  static bool isCollapsingAndFusible(LinalgOp producer,
+                                     TensorReshapeOp consumer,
+                                     unsigned consumerIdx) {
     return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
            producer.hasTensorSemantics() &&
            isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                     /*asProducer=*/false);
   }
 
-  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
-
+  static LinalgOp fuseCollapsingCase(LinalgOp producer,
+                                     TensorReshapeOp consumer,
+                                     unsigned consumerIdx,
+                                     PatternRewriter &rewriter) {
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
     SmallVector<AffineMap, 4> fusedIndexMaps =
@@ -882,6 +884,77 @@
         fusedRegion.begin());
     return fusedOp;
   }
+
+  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
+                                    unsigned consumerIdx) {
+    // Is fusible only if:
+    // 1) The producer is a generic op.
+    // 2) The producer has tensor semantics.
+    // 3) The tensor reshape op is an expanding case.
+    // 4) All the shapes are the same for the generic op.
+    // 5) All the indexing maps in producer are identity.
+    // 6) All the loops in producer are parallel loops.
+    // 7) The producer has a single user.
+    auto types = producer.getInputOutputShapedTypes();
+    assert(!types.empty());
+    return isa<GenericOp>(producer.getOperation()) &&
+           producer.hasTensorSemantics() &&
+           consumer.getSrcType().getRank() <
+               consumer.getResultType().getRank() &&
+           std::equal(types.begin() + 1, types.end(), types.begin()) &&
+           llvm::all_of(producer.getIndexingMaps(),
+                        [](AffineMap map) { return map.isIdentity(); }) &&
+           llvm::all_of(producer.iterator_types(),
+                        [](Attribute attr) {
+                          return attr.cast<StringAttr>().getValue() ==
+                                 getParallelIteratorTypeName();
+                        }) &&
+           producer.getOperation()->hasOneUse();
+  }
+
+  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
+                                    unsigned consumerIdx,
+                                    PatternRewriter &rewriter) {
+    Location loc = producer.getLoc();
+    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
+    SmallVector<Value, 4> args;
+    for (auto arg : producer.getOperation()->getOperands()) {
+      auto type = RankedTensorType::get(
+          dstShape, arg.getType().cast<ShapedType>().getElementType());
+      args.push_back(rewriter.createOrFold<TensorReshapeOp>(
+          loc, type, arg, consumer.reassociation()));
+    }
+
+    SmallVector<Type, 4> resultTypes;
+    for (auto t : producer.getOutputTensorTypes()) {
+      Type type = RankedTensorType::get(dstShape,
+                                        t.cast<ShapedType>().getElementType());
+      resultTypes.push_back(type);
+    }
+
+    int rank = dstShape.size();
+    int numArgsIn = producer.getNumInputs();
+    int numArgsOut = producer.getNumOutputs();
+    auto genericOp = rewriter.create<GenericOp>(
+        loc, resultTypes, args, numArgsIn, numArgsOut,
+        SmallVector<AffineMap, 4>(args.size() + resultTypes.size(),
+                                  rewriter.getMultiDimIdentityMap(rank)),
+        SmallVector<StringRef, 4>(rank, getParallelIteratorTypeName()));
+    Region &region = genericOp.getRegion();
+    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
+                               region.begin());
+    return cast<LinalgOp>(genericOp.getOperation());
+  }
+
+  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
+    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
+      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
+    if (isExpandingAndFusible(producer, consumer, consumerIdx))
+      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
+    return nullptr;
+  }
 };
 
 /// Implementation of fusion on tensor ops when producer is a splat constant.
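Note on the fold change above (illustrative, not part of the patch): with `operands` now threaded into `foldReshapeOp`, a `linalg.tensor_reshape` whose source is a constant folds directly to a constant of the reshaped type via `DenseElementsAttr::reshape`. That is also why the producer-side fusion now bails out when `producer.src()` is defined by a constant, leaving such cases to the folder. A minimal sketch of the IR before and after folding (made-up shapes and maps, not taken from the patch's tests):

    %cst = constant dense<1.000000e+00> : tensor<4x5xf32>
    %0 = linalg.tensor_reshape %cst
           [affine_map<(d0, d1, d2) -> (d0, d1)>,
            affine_map<(d0, d1, d2) -> (d2)>]
         : tensor<4x5xf32> into tensor<2x2x5xf32>

    // After folding, the reshape disappears and only the reshaped constant
    // remains (sketch):
    %0 = constant dense<1.000000e+00> : tensor<2x2x5xf32>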
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -222,6 +222,40 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
+    -> tensor<8x33x4xf32> {
+  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    %arg0, %cst {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+    tensor<264x4xf32> into tensor<8x33x4xf32>
+  return %1 : tensor<8x33x4xf32>
+}
+
+// The reshape op in `%arg0` is folded into the indexing map of generic op.
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generic_op_reshape_consumer_expanding
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK: tensor<264x4xf32> -> tensor<8x33x4xf32>
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
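For reference, an approximate sketch of the IR the new test expects after fusion, reconstructed from the CHECK lines above rather than copied from actual pass output (SSA names, attribute order, and the captured scalar constant are assumptions). The expanding reshape is absorbed by rewriting the generic op over the 8x33x4 iteration space, the splat constant becomes a scalar, and the remaining 2-D operand is addressed through the folded map (d0 * 33 + d1, d2):

    #map0 = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
    #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

    func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
        -> tensor<8x33x4xf32> {
      %cst = constant 2.000000e+00 : f32
      %0 = linalg.generic
        {args_in = 1 : i64, args_out = 1 : i64,
         indexing_maps = [#map0, #map1],
         iterator_types = ["parallel", "parallel", "parallel"]}
        %arg0 {
        ^bb0(%arg1: f32):
          %1 = mulf %arg1, %cst : f32
          linalg.yield %1 : f32
        }: tensor<264x4xf32> -> tensor<8x33x4xf32>
      return %0 : tensor<8x33x4xf32>
    }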