diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -855,12 +855,19 @@
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
     SmallVector<Value> returnedArgs;
-    for (Value yieldVal : yieldOp.values()) {
-      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+    for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
+      auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
       if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
-      returnedArgs.push_back(genericOp->getOperand(argumentNumber));
+      Value returnedArg = genericOp->getOperand(argumentNumber);
+      Type resultType = genericOp->getResult(yieldVal.index()).getType();
+      // The input can have a different type than the result, e.g. a dynamic
+      // input dimension can be turned into a static output dimension.
+      if (returnedArg.getType() != resultType)
+        returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
+                                                      resultType, returnedArg);
+      returnedArgs.push_back(returnedArg);
     }
     if (returnedArgs.size() != genericOp->getNumResults())
       return failure();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -179,6 +179,27 @@
 
 // -----
 
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
+  -> tensor<1x2x3xf32> {
+  %out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32>
+  %g = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%out : tensor<1x2x3xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  } -> (tensor<1x2x3xf32>)
+  return %g : tensor<1x2x3xf32>
+}
+// CHECK-LABEL: func @remove_no_op_mismatched_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
+// CHECK: return %[[CAST]]
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func @keep_not_noop(%arg0 : tensor) -> tensor {
   %c0 = arith.constant 0 : index