diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -901,3 +901,41 @@
 // CHECK-SAME: outs(%[[OUT_CAST]] :
 // CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[CONV]]
 // CHECK: return %[[CONV]], %[[RESULT_CAST]]
+
+// -----
+
+// A static-shape tensor.cast on one result of a two-result linalg.generic
+// should fold into the op: the input is cast to a static type, both init
+// tensors become static, and the other result is cast back to the original
+// dynamic type for its existing user.
+func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<2x3x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d2, %d0] : tensor<?x?x?xf32>
+  %init2 = linalg.init_tensor [%d2, %d1, %d0] : tensor<?x?x?xf32>
+  %0:2 = linalg.generic {
+    iterator_types = ["parallel", "parallel", "parallel"],
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+                     affine_map<(d0, d1, d2) -> (d2, d1, d0)>]}
+    ins(%arg0 : tensor<?x?x?xf32>) outs(%init1, %init2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32) :
+      linalg.yield %b0, %b0 : f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  %1 = tensor.cast %0#1 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %0#0, %1 : tensor<?x?x?xf32>, tensor<2x3x4xf32>
+}
+//       CHECK: func @fold_multi_use_generic_op_with_consumer
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
+//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32>
+//   CHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [3, 2, 4] : tensor<3x2x4xf32>
+//       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+//  CHECK-SAME:       ins(%[[CAST]] :
+//  CHECK-SAME:       outs(%[[INIT2]], %[[INIT1]] :
+//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
+//       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1