diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -183,8 +183,12 @@ Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), oldBlock->getArgumentTypes()); - // Add the result arguments to the new block. - for (Value v : newOutputBuffers) + // Add the result arguments that do not come from init_tensors to the new + // block. + // TODO: update this assumption because the reality is more complex under + // linalg on tensor based transformations. + for (Value v : + ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size())) newBlock->addArgument(v.getType().cast<MemRefType>().getElementType()); // Clone the body of the old block to the new block. diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir --- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir +++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir @@ -173,3 +173,142 @@ // know that things will play nicely at the C ABI boundary). 
func @print_memref_f32(%ptr : tensor<*xf32>) // CHECK-LABEL: func @print_memref_f32(memref<*xf32>) + +// ----- + +#accesses = [ + affine_map<(i, j, k) -> (j, i, k)>, + affine_map<(i, j, k) -> (i, j)> +] + +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>, + %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) { + + %0 = linalg.generic #trait + ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) + init(%arg1 : tensor<3x2xf32>) { + ^bb(%v0: vector<3x4xi4>, %v1: f32) : + %f0 = constant 0.0 : f32 + linalg.yield %f0 : f32 + } -> tensor<3x2xf32> + + return %0 : tensor<3x2xf32> +} +// CHECK-LABEL: func @generic_with_init_tensor +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) { +// CHECK-NEXT: linalg.generic +// CHECK: linalg.copy(%[[ARG1]], %[[RESULT0]]) +// CHECK-NEXT: return +// CHECK-NOT: % + +// ----- + +#accesses = [ + affine_map<(i, j, k) -> (j, i, k)>, + affine_map<(i, j, k) -> (i, j)> +] + +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func @init_tensor_with_2_uses(%arg0: tensor<2x3x4xvector<3x4xi4>>, + %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) { + + %0 = linalg.generic #trait + ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) + init(%arg1 : tensor<3x2xf32>) { + ^bb(%v0: vector<3x4xi4>, %v1: f32) : + %f0 = constant 0.0 : f32 + linalg.yield %f0 : f32 + } -> tensor<3x2xf32> + + %1 = linalg.generic #trait + ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) + init(%arg1 : tensor<3x2xf32>) { + ^bb(%v0: vector<3x4xi4>, %v1: f32) : + %f0 = constant 0.0 : f32 + linalg.yield %f0 : f32 + } -> tensor<3x2xf32> + + return %0, %1 : tensor<3x2xf32>, tensor<3x2xf32> +} +// CHECK-LABEL: func @init_tensor_with_2_uses +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>, 
%[[RESULT1:.*]]: memref<3x2xf32>) { +// CHECK-NEXT: %[[ALLOC0:.*]] = alloc +// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]]) +// CHECK-NEXT: linalg.generic +// CHECK-SAME: outs(%[[ALLOC0]] +// CHECK-NEXT: ^bb +// CHECK-NEXT: constant +// CHECK-NEXT: yield +// CHECK-NEXT: } +// CHECK-NEXT: %[[ALLOC1:.*]] = alloc +// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC1]]) +// CHECK-NEXT: linalg.generic +// CHECK-SAME: outs(%[[ALLOC1]] +// CHECK-NEXT: ^bb +// CHECK-NEXT: constant +// CHECK-NEXT: yield +// CHECK-NEXT: } +// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[RESULT0]]) +// CHECK-NEXT: dealloc +// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[RESULT1]]) +// CHECK-NEXT: dealloc +// CHECK-NEXT: return +// CHECK-NOT: % + +// ----- + +#accesses = [ + affine_map<(i, j, k) -> (j, i, k)>, + affine_map<(i, j, k) -> (i, j)> +] + +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func @init_tensor_with_1_use_def_chain(%arg0: tensor<2x3x4xvector<3x4xi4>>, + %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) { + + %0 = linalg.generic #trait + ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) + init(%arg1 : tensor<3x2xf32>) { + ^bb(%v0: vector<3x4xi4>, %v1: f32) : + %f0 = constant 0.0 : f32 + linalg.yield %f0 : f32 + } -> tensor<3x2xf32> + + %1 = linalg.generic #trait + ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) + init(%0 : tensor<3x2xf32>) { + ^bb(%v0: vector<3x4xi4>, %v1: f32) : + %f0 = constant 0.0 : f32 + linalg.yield %f0 : f32 + } -> tensor<3x2xf32> + + return %1 : tensor<3x2xf32> +} +// CHECK-LABEL: func @init_tensor_with_1_use_def_chain +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) { +// CHECK-NEXT: linalg.generic +// CHECK-NEXT: ^bb +// CHECK-NEXT: constant +// CHECK-NEXT: yield +// CHECK-NEXT: } +// CHECK-NEXT: linalg.generic +// CHECK-NEXT: ^bb +// CHECK-NEXT: constant +// CHECK-NEXT: yield +// CHECK-NEXT: } +// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT0]]) +// 
CHECK-NEXT: return +// CHECK-NOT: % diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -382,141 +382,3 @@ // CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]]) // CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) // CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] - -// ----- - -#accesses = [ - affine_map<(i, j, k) -> (j, i, k)>, - affine_map<(i, j, k) -> (i, j)> -] - -#trait = { - indexing_maps = #accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -func @generic_with_init_tensor( - %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) { - - %0 = linalg.generic #trait - ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) - init(%arg1 : tensor<3x2xf32>) { - ^bb(%v0: vector<3x4xi4>, %v1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 - } -> tensor<3x2xf32> - - return %0 : tensor<3x2xf32> -} -// CHECK-LABEL: func @generic_with_init_tensor -// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) { -// CHECK-NEXT: linalg.generic -// CHECK: linalg.copy(%[[ARG1]], %[[RESULT0]]) -// CHECK-NEXT: return -// CHECK-NOT: % - -// ----- - -#accesses = [ - affine_map<(i, j, k) -> (j, i, k)>, - affine_map<(i, j, k) -> (i, j)> -] - -#trait = { - indexing_maps = #accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -func @init_tensor_with_2_uses( - %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) { - - %0 = linalg.generic #trait - ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) - init(%arg1 : tensor<3x2xf32>) { - ^bb(%v0: vector<3x4xi4>, %v1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 - } -> tensor<3x2xf32> - - %1 = 
linalg.generic #trait - ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) - init(%arg1 : tensor<3x2xf32>) { - ^bb(%v0: vector<3x4xi4>, %v1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 - } -> tensor<3x2xf32> - - return %0, %1 : tensor<3x2xf32>, tensor<3x2xf32> -} -// CHECK-LABEL: func @init_tensor_with_2_uses -// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>, %[[RESULT1:.*]]: memref<3x2xf32>) { -// CHECK-NEXT: %[[ALLOC0:.*]] = alloc -// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]]) -// CHECK-NEXT: linalg.generic -// CHECK-SAME: outs(%[[ALLOC0]] -// CHECK-NEXT: ^bb -// CHECK-NEXT: constant -// CHECK-NEXT: yield -// CHECK-NEXT: } -// CHECK-NEXT: %[[ALLOC1:.*]] = alloc -// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC1]]) -// CHECK-NEXT: linalg.generic -// CHECK-SAME: outs(%[[ALLOC1]] -// CHECK-NEXT: ^bb -// CHECK-NEXT: constant -// CHECK-NEXT: yield -// CHECK-NEXT: } -// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[RESULT0]]) -// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[RESULT1]]) -// CHECK-NEXT: return -// CHECK-NOT: % - -// ----- - -#accesses = [ - affine_map<(i, j, k) -> (j, i, k)>, - affine_map<(i, j, k) -> (i, j)> -] - -#trait = { - indexing_maps = #accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -func @init_tensor_with_1_use_def_chain( - %arg0: tensor<2x3x4xvector<3x4xi4>>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) { - - %0 = linalg.generic #trait - ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) - init(%arg1 : tensor<3x2xf32>) { - ^bb(%v0: vector<3x4xi4>, %v1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 - } -> tensor<3x2xf32> - - %1 = linalg.generic #trait - ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>) - init(%0 : tensor<3x2xf32>) { - ^bb(%v0: vector<3x4xi4>, %v1: f32) : - %f0 = constant 0.0 : f32 - linalg.yield %f0 : f32 - } -> tensor<3x2xf32> - - return %1 : tensor<3x2xf32> -} -// CHECK-LABEL: func @init_tensor_with_1_use_def_chain -// CHECK-SAME: 
(%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) { -// CHECK-NEXT: linalg.generic -// CHECK-NEXT: ^bb -// CHECK-NEXT: constant -// CHECK-NEXT: yield -// CHECK-NEXT: } -// CHECK-NEXT: linalg.generic -// CHECK-NEXT: ^bb -// CHECK-NEXT: constant -// CHECK-NEXT: yield -// CHECK-NEXT: } -// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT0]]) -// CHECK-NEXT: return -// CHECK-NOT: % -