diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -162,13 +162,24 @@ // guarantees at least one such dimension is found. If multiple candidates exist // they must agree by construction (i.e. have the same size) and we just return // the first one. -static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, - unsigned loopDepth) { +static ShapeDimension +getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, + bool fromSubViewOpOnly = false) { auto maps = op.indexing_maps(); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { + // The method `getRangeFromOperandShape` requires using SubViewOp or + // SubTensorOps. If the value isn't defined by one of them, skip it. + // TODO: The method should be adapted to get the values from + // `ViewInterface`. The interface needs a `getOrCreateRanges` method which + // currently returns a `linalg.range`. The fix here is to move this op to + // `std` dialect and add the method to `ViewInterface`. 
+ if (fromSubViewOpOnly && + !isa_and_nonnull(en.value().getDefiningOp())) + continue; + unsigned idx = en.index(); auto map = maps[idx].cast().getValue(); LLVM_DEBUG(llvm::dbgs() @@ -821,7 +832,7 @@ builder.setInsertionPoint(tiledOp); DenseMap fusedLoopsAndRanges; for (unsigned loop : fusedLoops) { - ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); + ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true); fusedLoopsAndRanges[loop] = getRangeFromOperandShape( builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); } diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -426,3 +426,30 @@ return } } + +// ----- + +module { + func @basic_conv_fusion(%arg0: memref, %arg1: memref, + %arg2: memref) { + %cst = constant 0.000000e+00 : f32 + linalg.fill(%arg2, %cst) : memref, f32 + linalg.conv(%arg0, %arg1, %arg2) { + dilations = [1, 1], strides = [1, 1], + __internal_linalg_transform__ = "basic_fusion"} : + memref, memref, memref + return + } +} +// CHECK: func @basic_conv_fusion +// CHECK: linalg.fill +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" +// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}}) +// CHECK-SAME: { +// CHECK: linalg.fill +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" +// CHECK: linalg.conv +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" +// CHECK: } +// CHECK: linalg.conv +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -38,7 +38,8 @@ static void fillFusionPatterns(MLIRContext *context, const 
LinalgDependenceGraph &dependenceGraph, OwningRewritePatternList &patterns) { - patterns.insert>( + patterns.insert, + LinalgTileAndFusePattern>( context, dependenceGraph, LinalgTilingOptions() .setTileSizes({32, 64, 16})