diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2276,13 +2276,18 @@
     SmallVector<Value, 4> returnedArgs;
     for (Value yieldVal : yieldOp.values()) {
       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
-      if (!yieldArg)
+      // Only arguments of this op's own body map back to its operands; a block
+      // argument owned by some other block (outside the op) does not.
+      if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
       if (argumentNumber < numIndexArgs)
         return failure();
       returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
     }
+    // replaceOp needs exactly one replacement value per result of the op.
+    if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
+      return failure();
     rewriter.replaceOp(genericOp, returnedArgs);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -615,3 +615,60 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK: return %[[ARG1]], %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+// The yielded value is a block argument defined outside the linalg.generic
+// body, so the op is not a pass-through of its operands and must be kept.
+func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg1 : f32):
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg1 : f32
+    } -> tensor<?x?xf32>
+  return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK: return %[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+// Only one of the two yielded values is an argument of the op's body; the
+// other comes from outside it, so the op must still be kept.
+func @keep_not_noop_multiple_results(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg2 : f32):
+  %3:2 = linalg.generic
+    {indexing_maps = [#map, #map, #map, #map],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
+      linalg.yield %arg2, %arg4 : f32, f32
+    } -> tensor<?x?xf32>, tensor<?x?xf32>
+  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop_multiple_results
+// CHECK: %[[RESULT:.+]]:2 = linalg.generic
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1