diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1129,9 +1129,12 @@
     SmallVector<Value> newOutputs;
     SmallVector<Type> newOutputTypes;
     for (auto output : op.outputs()) {
+      auto newOutputType = RankedTensorType::get(
+          reshapeFound.getSrcType().getShape(),
+          output.getType().template cast<RankedTensorType>().getElementType());
       Value newOutput = rewriter.create<TensorReshapeOp>(
-          op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
-      newOutputTypes.push_back(newOutput.getType());
+          op->getLoc(), newOutputType, output, reassociation);
+      newOutputTypes.push_back(newOutputType);
       newOutputs.push_back(newOutput);
     }
     // 5. Create a new generic op with lowerer rank.
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -88,3 +88,40 @@
   } -> tensor<112x112x16xf32>
   return %22 : tensor<112x112x16xf32>
 }
+
+// -----
+
+func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
+                       %arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
+  %cst_6 = constant 1.000000e+00 : f32
+  %cst_7 = constant 7.000000e+00 : f32
+  %cst_8 = constant 1.1920929E-7 : f32
+  %25 = linalg.tensor_reshape %arg0
+          [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+          : tensor<6x5xi32> into tensor<2x3x5xi32>
+  %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
+  %28 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                           affine_map<(d0, d1, d2) -> (d2)>,
+                           affine_map<(d0, d1, d2) -> (d2)>,
+                           affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+          iterator_types = ["parallel", "parallel", "parallel"]}
+          ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
+          outs(%26 : tensor<2x3x5xf32>) {
+  ^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32):  // no predecessors
+    %29 = sitofp %arg6 : i32 to f32
+    %30 = addf %arg7, %cst_8 : f32
+    %31 = divf %cst_7, %30 : f32
+    %32 = divf %cst_6, %31 : f32
+    %33 = mulf %29, %32 : f32
+    %34 = addf %33, %arg8 : f32
+    linalg.yield %34 : f32
+  } -> tensor<2x3x5xf32>
+  return %28 : tensor<2x3x5xf32>
+}
+// CHECK-LABEL: func @type_correctness
+// CHECK: %[[OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
+// CHECK: linalg.tensor_reshape %[[OP]]
+// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>