diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -104,16 +104,18 @@
   return success();
 }
 
-// Specialization for `linalg::GenericOp`.
+/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`.
 /// A pattern to convert Generic Linalg operations which work on tensors to
 /// use buffers. BufferPlacement pass should be later used to move
 /// Alloc operations to the correct positions and insert the missing Dealloc
 /// operations in the correct places.
-static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
-                                     linalg::GenericOp genericOp,
-                                     ValueRange inputs, ValueRange outputs) {
+template <typename GenericOpTy>
+static void
+finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
+                                     GenericOpTy genericOp, ValueRange inputs,
+                                     ValueRange outputs) {
   // Generate a new linalg operation that works on buffers.
-  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+  auto newGenericOp = rewriter.create<GenericOpTy>(
       genericOp.getLoc(),
       /*resultTensorTypes=*/llvm::None,
       /*inputs=*/inputs,
@@ -147,9 +149,7 @@
   rewriter.replaceOp(genericOp, outputs);
 }
 
-// TODO: Specialization for `linalg::IndexedGenericOp`.
-
-// Specialization for all other `linalg::LinalgOp`.
+/// Specialization for all other `linalg::LinalgOp`.
 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                      linalg::LinalgOp linalgOp,
                                      ValueRange inputs, ValueRange outputs) {
@@ -207,8 +207,15 @@
 
     // Delegate to the linalg generic pattern.
     if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
-      finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
-                               newOutputBuffers);
+      finalizeBufferAllocationForGenericOp<linalg::GenericOp>(
+          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
+      return success();
+    }
+
+    // Delegate to the linalg indexed generic pattern.
+    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
+      finalizeBufferAllocationForGenericOp<linalg::IndexedGenericOp>(
+          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
       return success();
     }
 
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -45,6 +45,7 @@
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%{{.*}} : memref<4xf32>)
 // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>)
+// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
 func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
     %0, %1 = linalg.generic {
       indexing_maps = [#map0, #map0, #map0],
@@ -59,6 +60,31 @@
 
 // -----
 
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @multiple_results_indexed
+// CHECK: %[[RESULT0:.*]] = alloc() : memref<4xi32>
+// CHECK: %[[RESULT1:.*]] = alloc() : memref<4xi32>
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: ins(%{{.*}} : memref<4xi32>)
+// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xi32>, memref<4xi32>)
+// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32):
+func @multiple_results_indexed(%arg0: tensor<4xi32>)
+        -> (tensor<4xi32>, tensor<4xi32>) {
+    %0, %1 = linalg.indexed_generic {
+      indexing_maps = [#map0, #map0, #map0],
+      iterator_types = ["parallel"]
+    } ins(%arg0 : tensor<4xi32>) {
+      ^bb0(%i: index, %gen_arg1: i32):
+        %i_i32 = index_cast %i : index to i32
+        %tmp1 = addi %gen_arg1, %i_i32 : i32
+        linalg.yield %tmp1, %tmp1 : i32, i32
+    } -> tensor<4xi32>, tensor<4xi32>
+    return %0, %1 : tensor<4xi32>, tensor<4xi32>
+}
+
+// -----
+
 #map_2d = affine_map<(d0, d1) -> (d0, d1)>
 #map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
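
Reviewer note: the new test only CHECKs the op header and the entry-block signature, so for context here is a hand-written sketch of roughly what the bufferized @multiple_results_indexed body is expected to look like. It is not output of the pass; the names %in, %out0, %out1, %a, %b, %c are illustrative, and the conversion of the input tensor operand to %in is omitted.

  #map0 = affine_map<(d0) -> (d0)>

  // %in stands in for the buffer the type converter supplies for %arg0;
  // %out0/%out1 correspond to the allocs matched by RESULT0/RESULT1 above.
  %out0 = alloc() : memref<4xi32>
  %out1 = alloc() : memref<4xi32>
  linalg.indexed_generic {
    indexing_maps = [#map0, #map0, #map0],
    iterator_types = ["parallel"]
  } ins(%in : memref<4xi32>)
    outs(%out0, %out1 : memref<4xi32>, memref<4xi32>) {
  // On buffers, the body takes the loop index, one value per input element,
  // and one value per output element, hence (index, i32, i32, i32).
  ^bb0(%i: index, %a: i32, %b: i32, %c: i32):
    %i_i32 = index_cast %i : index to i32
    %sum = addi %a, %i_i32 : i32
    linalg.yield %sum, %sum : i32, i32
  }

The extra output block arguments (%b, %c) are exactly what the added CHECK-NEXT lines on ^bb0 verify for both linalg.generic and linalg.indexed_generic.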