This is an archive of the discontinued LLVM Phabricator instance.

Lowering for 'tosa.scatter'
ClosedPublic

Authored by rafaelubalmw on May 22 2023, 9:26 AM.

Details

Summary

This patch adds support for tosa.scatter lowering in the --tosa-to-scf pass. Here's an example of this lowering:

func.func @tosa(
                %valuesIn : tensor<3x7x5xi32>,
                %indices : tensor<3x6xi32>,
                %input : tensor<3x6x5xi32>) ->
                tensor<3x7x5xi32> {
        %0 = "tosa.scatter"(%valuesIn, %indices, %input) :
                        (tensor<3x7x5xi32>,
                        tensor<3x6xi32>,
                        tensor<3x6x5xi32>) ->
                        (tensor<3x7x5xi32>)
        return %0 : tensor<3x7x5xi32>
}

translates to

func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  %c2 = arith.constant 2 : index
  %c5 = arith.constant 5 : index
  %c0_0 = arith.constant 0 : index
  %c1_1 = arith.constant 1 : index
  %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
    %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
      %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
      %2 = arith.index_cast %extracted : i32 to index
      %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
      %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
      scf.yield %inserted_slice : tensor<3x7x5xi32>
    }
    scf.yield %1 : tensor<3x7x5xi32>
  }
  return %0 : tensor<3x7x5xi32>
}
We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:

- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt it to the needs of `tensor.scatter` (a sketch of the target form follows this list). While this overhead may later be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason we are missing to go through the `tensor` dialect, this additional complexity does not seem justified.
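
For reference, this is a rough sketch (not code from this patch; the function and value names are illustrative) of what the same computation looks like when phrased as a `tensor.scatter`, reusing the shapes of the example above. The op expects each scattered slice to carry a full coordinate into the destination's `scatter_dims`, with `index` element type, whereas `tosa.scatter` only provides a single i32 position along the K dimension, so the batch coordinate has to be materialized separately:

func.func @tensor_scatter_form(
    %valuesIn : tensor<3x7x5xi32>,
    %slices : tensor<3x6x1x1x5xi32>,   // the 3x6x5 input regrouped into 1x1x5 slices
    %fullIndices : tensor<3x6x2xindex> // full (batch, row) coordinate per slice
    ) -> tensor<3x7x5xi32> {
  %0 = tensor.scatter %slices into %valuesIn[%fullIndices]
      scatter_dims([0, 1]) unique
      : (tensor<3x6x1x1x5xi32>, tensor<3x7x5xi32>, tensor<3x6x2xindex>)
      -> tensor<3x7x5xi32>
  return %0 : tensor<3x7x5xi32>
}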

Diff Detail

Event Timeline

rafaelubalmw created this revision. May 22 2023, 9:26 AM
Herald added a project: Restricted Project. May 22 2023, 9:26 AM
rafaelubalmw requested review of this revision. May 22 2023, 9:26 AM
rafaelubalmw edited the summary of this revision. May 22 2023, 9:45 AM
rafaelubalmw added a reviewer: Restricted Project.
rafaelubalmw edited reviewers, added: silvas, mehdi_amini, rriddle, jpienaar; removed: Restricted Project. May 22 2023, 10:02 AM
eric-k256 accepted this revision. May 25 2023, 9:06 PM

Thanks for taking this on. Scatter is a complicated op, and it is good to have a working legalization to linalg.

This revision is now accepted and ready to land. May 25 2023, 9:06 PM

to have a working legalization to linalg.

Is there a path from there to linalg? It's not clear to me how it works.
And actually I'm wondering whether this lowering could indeed be expressed with a linalg.generic?

to have a working legalization to linalg.

Is there a path from there to linalg? It's not clear to me how it works.
And actually I'm wondering whether this lowering could indeed be expressed with a linalg.generic?

Sorry, I misspoke and wrote linalg while meaning scf/tensor.
With scatter, I didn't think it would be possible to use linalg.generic as you don't have an affine map for the output because you are dependent on the values in the index tensor. Perhaps I'm missing an implementation option.
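
To illustrate the distinction (an illustrative sketch, not code from this patch; the function and value names are made up): a gather-like computation fits linalg.generic because the data-dependent access is a read, which can be expressed with tensor.extract inside the body, whereas a scatter would need the write position itself to depend on runtime index values, and the output position of a linalg.generic is fixed by its affine indexing map.

#idx_map = affine_map<(b, k, c) -> (b, k)>
#out_map = affine_map<(b, k, c) -> (b, k, c)>
func.func @gather_like(%src : tensor<3x7x5xi32>,
                       %indices : tensor<3x6xi32>) -> tensor<3x6x5xi32> {
  %init = tensor.empty() : tensor<3x6x5xi32>
  %0 = linalg.generic {
      indexing_maps = [#idx_map, #out_map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } ins(%indices : tensor<3x6xi32>)
      outs(%init : tensor<3x6x5xi32>) {
  ^bb0(%row_i32 : i32, %out : i32):
    %b = linalg.index 0 : index
    %c = linalg.index 2 : index
    // Data-dependent *read*: which row of %src is loaded depends on a runtime
    // value, but the position written in the output is fixed by #out_map.
    %row = arith.index_cast %row_i32 : i32 to index
    %val = tensor.extract %src[%b, %row, %c] : tensor<3x7x5xi32>
    linalg.yield %val : i32
  } -> tensor<3x6x5xi32>
  return %0 : tensor<3x6x5xi32>
}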

to have a working legalization to linalg.

Is there a path from there to linalg? It's not clear to me how it works.
And actually I'm wondering whether this lowering could indeed be expressed with a linalg.generic?

We were unable to devise a lowering to linalg.generic directly. As @eric-k256 noted, this is because the output indexing is driven by the values in the indices tensor rather than by an affine map.

The IREE project created its own op in its 'LinalgExt' dialect to accommodate a more gradual lowering: https://github.com/openxla/iree/blob/97779d7f494660f88864b035475ec77a1e54c6c8/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td#L62

The tensor.scatter operation is an obvious lowering target for tosa.scatter, but

  1. There is no lowering out of tensor.scatter at the moment (discussed a bit here: https://discourse.llvm.org/t/lowering-of-scatter-operations/70535)
  2. The tosa.scatter -> tensor.scatter lowering is not as straightforward as the names would suggest

Just to elaborate on (2), this was our attempt to grab a tosa.scatter instance and lower it to tensor.scatter. The conversion of the indexing semantics involved introducing an intermediate linalg.generic step:

func.func @main(
		%valuesIn : tensor<3x7x5xi32>,
		%indices : tensor<3x6xi32>,
		%input : tensor<3x6x5xi32>) ->
		tensor<3x7x5xi32> {
	%0 = "tosa.scatter"(%valuesIn, %indices, %input) :
			(tensor<3x7x5xi32>,
			tensor<3x6xi32>,
			tensor<3x6x5xi32>) ->
			(tensor<3x7x5xi32>)
	return %0 : tensor<3x7x5xi32>
}

lowers to tensor.scatter as

func.func @main(
    %valuesIn : tensor<3x7x5xi32>,
    %indices : tensor<3x6xi32>,
    %input : tensor<3x6x5xi32>) ->
    tensor<3x7x5xi32> {
  // Regroup the input so that each scattered slice has the shape expected by
  // tensor.scatter: the 3x6 batch dims followed by a 1x1x5 slice of the output.
  %shape = arith.constant dense<[3, 6, 1, 1, 5]> : tensor<5xindex>
  %reshapedInput = tensor.reshape %input(%shape)
      : (tensor<3x6x5xi32>, tensor<5xindex>) -> tensor<3x6x1x1x5xi32>
  // Rebuild the index tensor: tensor.scatter needs a full (batch, row)
  // coordinate per slice, so coordinate 0 is the batch index and coordinate 1
  // is the original tosa index.
  %emptyNewIndices = tensor.empty() : tensor<3x6x2xindex>
  %newIndices = linalg.generic {
      indexing_maps = [
        affine_map<(i, j, k) -> (i, j)>,
        affine_map<(i, j, k) -> (i, j, k)>
      ],
      iterator_types = ["parallel", "parallel", "parallel"]
    } ins(%indices : tensor<3x6xi32>)
      outs(%emptyNewIndices : tensor<3x6x2xindex>) {
  ^bb0(%index : i32, %out : index):
    %i = linalg.index 0 : index
    %k = linalg.index 2 : index
    %zero = arith.constant 0 : index
    %isBatchCoord = arith.cmpi eq, %k, %zero : index
    %indexCast = arith.index_cast %index : i32 to index
    %ret = arith.select %isBatchCoord, %i, %indexCast : index
    linalg.yield %ret : index
  } -> tensor<3x6x2xindex>
  %result = tensor.scatter %reshapedInput into %valuesIn[%newIndices]
      scatter_dims([0, 1]) unique
      : (tensor<3x6x1x1x5xi32>, tensor<3x7x5xi32>, tensor<3x6x2xindex>)
      -> tensor<3x7x5xi32>
  return %result : tensor<3x7x5xi32>
}

In our exploration of this path, we also created a lowering pattern for tensor.scatter -> scf (which we might upstream soon anyway). But we concluded that the additional overhead of the index conversion did not justify using tensor.scatter as an intermediate representation when lowering tosa.scatter.

Thanks for the detailed information!

This revision was automatically updated to reflect the committed changes.