diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+
+// This is a simple tile-and-fuse example with a single fusion group.
+//
+// The payload function @foo chains linalg.fill -> linalg.matmul ->
+// linalg.generic. The transform script at the bottom of the module matches the
+// op tagged __root__, tiles it, and fuses every op tagged __producer__ into the
+// scf.foreach_thread op produced by tiling.
+
+module {
+  // The checks below expect fill, matmul and generic to show up, in that order,
+  // after the scf.foreach_thread header.
+  // CHECK: func @foo
+  // CHECK:   scf.foreach_thread {{.*}} {
+  // CHECK:     linalg.fill
+  // CHECK:     linalg.matmul
+  // CHECK:     linalg.generic
+  // CHECK:   }
+  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
+                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
+      -> tensor<?x?xf32>
+  {
+    %cst = arith.constant 0.000000e+00 : f32
+    // Fill %D with zero; the result is the matmul's init operand.
+    %5 = linalg.fill
+        {__producer__}
+        ins(%cst : f32)
+        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %6 = linalg.matmul
+        {__producer__}
+        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    // Elementwise consumer of the matmul result; %C is indexed by d0 only. Yields
+    // %cst where the %C element is > %cst, otherwise max(matmul element, %cst).
+    %7 = linalg.generic
+        {__root__,
+         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                          affine_map<(d0, d1) -> (d0, d1)>,
+                          affine_map<(d0, d1) -> (d0, d1)>],
+         iterator_types = ["parallel", "parallel"]
+        }
+        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
+        outs(%D : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+      %16 = arith.maxf %arg3, %cst : f32
+      %17 = arith.cmpf ogt, %arg2, %cst : f32
+      %18 = arith.select %17, %cst, %16 : f32
+      linalg.yield %18 : f32
+    } -> tensor<?x?xf32>
+    return %7 : tensor<?x?xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      // Find the root and all producers.
+      %root = transform.structured.match attribute{"__root__"} in %arg1
+      %producers = transform.structured.match attribute{"__producer__"} in %arg1
+
+      // Tile the root.
+      %foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20]
+
+      // Fuse all producers.
+      transform.structured.fuse_into_containing_op %producers into %foreach_thread_op
+    }
+  }
+}