diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -38,6 +38,8 @@ using namespace mlir::linalg;
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Check if the operation is a LinalgOp but not a GenericOp.
-  if (isa<GenericOp>(linalgOp))
+  // Bail out if `linalgOp` is already a generic or a linalg.map. We cannot
+  // trivially generalize a `linalg.map`, as it does not use the output as
+  // region arguments in the block.
+  if (isa<GenericOp, MapOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -272,3 +272,17 @@
 // CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
 // CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
+// CHECK-LABEL: generalize_linalg_map
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: linalg.map
+  // CHECK-NOT: linalg.generic
+  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
+    () {
+      linalg.yield %cst : f32
+    }
+  return
+}