diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -39,9 +39,12 @@
   // Check if the operation is a LinalgOp but not a GenericOp.
   if (isa<GenericOp>(linalgOp))
     return failure();
-  // Check if the operation has a region builder.
-  if (!linalgOp.getRegionBuilder())
+  // Check if the operation has exactly one region.
+  if (linalgOp->getNumRegions() != 1) {
+    assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
+    // TODO: Otherwise it needs to be built explicitly from the region builder.
     return failure();
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
@@ -10,8 +10,23 @@
   return %0 : tensor<?x?xf32>
 }
 
+// CHECK-LABEL: func @generalize_reduce(
+func.func @generalize_reduce(%input: tensor<16x32x64xf32>,
+                             %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  // CHECK-NOT: linalg.reduce
+  // CHECK: linalg.generic
+  %reduce = linalg.reduce
+              ins(%input:tensor<16x32x64xf32>)
+              outs(%init:tensor<16x64xf32>)
+              dimensions = [1]
+              (%in: f32, %out: f32) {
+                %0 = arith.addf %out, %in: f32
+                linalg.yield %0: f32
+              }
+  func.return %reduce : tensor<16x64xf32>
+}
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1 = transform.structured.generalize %0
 }
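
Note: for reference, the IR that transform.structured.generalize is expected to
produce for the new reduce test looks roughly like the following. This is a
sketch, assuming the standard lowering of linalg.reduce to linalg.generic; the
#map names are illustrative and not part of the patch:

  // Reducing dim 1 of a 16x32x64 input into a 16x64 init:
  // d1 is the reduction iterator, d0/d2 stay parallel.
  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
  %0 = linalg.generic {indexing_maps = [#map, #map1],
                       iterator_types = ["parallel", "reduction", "parallel"]}
      ins(%input : tensor<16x32x64xf32>) outs(%init : tensor<16x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %out, %in : f32
    linalg.yield %1 : f32
  } -> tensor<16x64xf32>

Since linalg.reduce carries its combiner region inline, the pass can now reuse
that single region directly instead of requiring a region builder.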