diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -834,11 +834,77 @@
                                     /*asProducer=*/false);
   }
 
+  static LinalgOp propagateReshapeOpToOperands(LinalgOp producer,
+                                               TensorReshapeOp consumer,
+                                               unsigned consumerIdx,
+                                               PatternRewriter &rewriter) {
+    if (isa<IndexedGenericOp>(producer.getOperation()))
+      return nullptr;
+
+    // All the shapes are the same for the generic op.
+    auto types = producer.getInputOutputShapedTypes();
+    if (!types.empty() &&
+        !std::equal(types.begin() + 1, types.end(), types.begin())) {
+      return nullptr;
+    }
+
+    for (auto indexingMap : producer.getIndexingMaps()) {
+      if (!indexingMap.isIdentity())
+        return nullptr;
+    }
+
+    // All loops are parallel loops.
+    for (auto iteratorType : producer.iterator_types()) {
+      if (iteratorType.cast<StringAttr>().getValue() !=
+          getParallelIteratorTypeName())
+        return nullptr;
+    }
+
+    Location loc = producer.getLoc();
+    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
+    SmallVector<Value, 4> args;
+    for (auto arg : producer.getOperation()->getOperands()) {
+      auto type = RankedTensorType::get(
+          dstShape, arg.getType().cast<ShapedType>().getElementType());
+      if (auto cstOp = arg.getDefiningOp<ConstantOp>()) {
+        args.push_back(rewriter.create<ConstantOp>(
+            loc, cstOp.value().cast<DenseElementsAttr>().reshape(type)));
+      } else {
+        args.push_back(rewriter.create<TensorReshapeOp>(
+            loc, type, arg, consumer.reassociation()));
+      }
+    }
+
+    SmallVector<Type, 4> resultTypes;
+    for (auto t : producer.getOutputTensorTypes()) {
+      Type type = RankedTensorType::get(dstShape,
+                                        t.cast<ShapedType>().getElementType());
+      resultTypes.push_back(type);
+    }
+
+    int rank = dstShape.size();
+    int numArgsIn = producer.getNumInputs();
+    int numArgsOut = producer.getNumOutputs();
+    auto genericOp = rewriter.create<GenericOp>(
+        loc, resultTypes, args, numArgsIn, numArgsOut,
+        SmallVector<AffineMap, 4>(args.size() + resultTypes.size(),
+                                  rewriter.getMultiDimIdentityMap(rank)),
+        SmallVector<StringRef, 4>(rank, getParallelIteratorTypeName()));
+    Region &region = genericOp.getRegion();
+    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
+                               region.begin());
+    return cast<LinalgOp>(genericOp.getOperation());
+  }
+
   static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
                        unsigned consumerIdx, PatternRewriter &rewriter,
                        OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
+    if (!isFusible(producer, consumer, consumerIdx)) {
+      // If the reshape op cannot be fused into the producer directly, try to
+      // propagate the reshape to the producer's operands instead.
+      return propagateReshapeOpToOperands(producer, consumer, consumerIdx,
+                                          rewriter);
+    }
 
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -222,6 +222,36 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
+    -> tensor<8x33x4xf32> {
+  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    %arg0, %cst {
+  ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+    %2 = mulf %arg1, %arg2 : f32
+    linalg.yield %2 : f32
+  }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+    tensor<264x4xf32> into tensor<8x33x4xf32>
+  return %1 : tensor<8x33x4xf32>
+}
+
+// The reshape op on the result is propagated to the operands of the generic op.
+// CHECK-LABEL: func @generic_op_reshape_consumer_expanding
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>