
[mlir][vector] Allow values outside of [0; dim-size] in create_mask
Status: Closed (Public)

Authored by sgrechanik on Dec 20 2021, 5:59 PM.

Details

Summary

This commit explicitly states that negative values and values exceeding
the vector dimensions are allowed in vector.create_mask (but not in
vector.constant_mask). Such values are now truncated (clamped to
[0, dim-size]) when canonicalizing vector.create_mask to vector.constant_mask.
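To make the summary concrete, here is a small Python model of the proposed semantics. It is a sketch, not MLIR code. `create_mask`, `constant_mask`, and `fold_to_constant_mask` are hypothetical helpers that represent a mask as a flat, row-major list of 0/1. `create_mask` accepts any signed sizes, the fold clamps each constant operand to [0, dim-size], and `constant_mask` keeps its strict bounds check.

```python
from itertools import product


def create_mask(sizes, shape):
    # Lane (i0, ..., in) is set iff i_k < sizes[k] in every dimension k.
    # Sizes are plain signed integers: negatives act like 0, and values
    # above the dimension size act like the dimension size.
    return [int(all(i < s for i, s in zip(idx, sizes)))
            for idx in product(*(range(d) for d in shape))]


def constant_mask(dim_sizes, shape):
    # constant_mask keeps the strict static check: 0 <= size <= dim.
    for s, d in zip(dim_sizes, shape):
        if not 0 <= s <= d:
            raise ValueError(f"mask size {s} out of range [0, {d}]")
    return create_mask(dim_sizes, shape)


def fold_to_constant_mask(sizes, shape):
    # Canonicalization sketch: clamp each constant operand into [0, dim].
    clamped = [min(max(s, 0), d) for s, d in zip(sizes, shape)]
    # A zero in any dimension empties the whole mask; normalize to all zeros.
    if 0 in clamped:
        clamped = [0] * len(clamped)
    return clamped


shape = [2, 4]
print(fold_to_constant_mask([1, 7], shape))   # [1, 4]
print(create_mask([1, 7], shape))             # [1, 1, 1, 1, 0, 0, 0, 0]
```

For sizes [-3, 7] on a 2x4 mask, the sketch folds to [0, 0], and both forms produce an all-zero mask.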


Event Timeline

sgrechanik created this revision. Dec 20 2021, 5:59 PM
sgrechanik requested review of this revision. Dec 20 2021, 5:59 PM

Thanks for the contribution, Sergei! I don't have enough experience with this op, so I'll leave this to @nicolasvasilache.
How do you end up generating values that are outside the expected bounds of the mask?
Folding invalid values into valid ones could be surprising and might lead to silent bugs, though maybe it makes sense for this op. An alternative would be to generate code in your use case that handles these out-of-bounds values, truncating them before they are passed to the create_mask op. It would be useful if you could share a bit more about your use case.

Thanks,
Diego

@aartbik is actually the original contributor of this abstraction and the main user at this time.
Offhand, it seems to me that we wouldn't want negative values here?
I would personally rather go for an explicit truncation, but @aartbik will know better.
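For comparison, the explicit truncation suggested here can be modeled as a clamp placed in front of a strict op. In this Python sketch (hypothetical helpers, 1-D only), `strict_mask` mirrors the bounds check of vector.constant_mask and `lenient_mask` mirrors the semantics proposed in this revision. The loop checks that both routes agree for every tested input.

```python
def strict_mask(n, dim):
    # Only accepts 0 <= n <= dim, like the constant_mask verifier.
    if not 0 <= n <= dim:
        raise ValueError(f"mask size {n} out of range [0, {dim}]")
    return [1 if i < n else 0 for i in range(dim)]


def lenient_mask(n, dim):
    # Semantics proposed in this revision: any signed n is accepted.
    return [1 if i < n else 0 for i in range(dim)]


def explicitly_truncated_mask(n, dim):
    # What generated code would do without the relaxed semantics:
    # take max(n, 0), then min(..., dim), then build a strict mask.
    return strict_mask(min(max(n, 0), dim), dim)


for n in range(-5, 15):
    assert explicitly_truncated_mask(n, 8) == lenient_mask(n, 8)
```

Both routes yield the same masks. The trade-off is whether the clamp appears as extra min/max operations in the IR or is implied by create_mask itself.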

The vectorizer already produces values greater than the vector size when it vectorizes reductions and creates a mask. In this example, the mask filters out garbage elements (with index >= 400) and is based on the value %elts_left, which is often greater than 64:

func @vecdim_reduction_masked(%arg0: memref<?xf32>, %arg1: memref<f32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant dense<0.000000e+00> : vector<64xf32>
  %0 = affine.for %arg2 = 0 to 400 step 64 iter_args(%arg3 = %cst_0) -> (vector<64xf32>) {
    %elts_left = affine.apply affine_map<(d0) -> (400 - d0)>(%arg2)
    %3 = vector.create_mask %elts_left : vector<64xi1>
    %4 = vector.transfer_read %arg0[%arg2], %cst : memref<?xf32>, vector<64xf32>
    %5 = arith.addf %arg3, %4 : vector<64xf32>
    %6 = select %3, %5, %arg3 : vector<64xi1>, vector<64xf32>
    affine.yield %6 : vector<64xf32>
  }
  %1 = vector.reduction "add", %0 : vector<64xf32> into f32
  affine.store %1, %arg1[] : memref<f32>
  return
}

(It then fails when loop iterations are peeled and unrolled: the create_mask ops become constant_mask ops, which statically check that the value is within bounds.)
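The reduction above can be simulated in Python to show the mask at work. This is a sketch, not the vectorizer's output. `masked_reduction` is a hypothetical name, and the input is padded with garbage values past index 400 so every 64-wide read is full. The trace of `%elts_left` shows that every iteration except the last asks for a mask larger than the vector. Once iterations are unrolled with constant induction values, that is exactly the case that turns into an out-of-range constant_mask.

```python
VF = 64   # vector<64xf32>
N = 400   # bound used in elts_left = 400 - iv


def masked_reduction(data, upper):
    """Python model of @vecdim_reduction_masked with loop bound `upper`."""
    acc = [0.0] * VF                                   # %cst_0 / %arg3
    elts_left_trace = []
    for iv in range(0, upper, VF):                     # step 64
        elts_left = N - iv                             # may exceed VF
        elts_left_trace.append(elts_left)
        mask = [lane < elts_left for lane in range(VF)]  # create_mask
        chunk = data[iv:iv + VF]                       # transfer_read
        summed = [a + x for a, x in zip(acc, chunk)]   # arith.addf
        # select: keep the new sum only in masked-on lanes
        acc = [s if m else a for s, a, m in zip(summed, acc, mask)]
    return sum(acc), elts_left_trace                   # vector.reduction "add"


valid = [float(i % 7) for i in range(N)]
data = valid + [1000.0] * (448 - N)   # garbage past index 400
total, trace = masked_reduction(data, 400)
print(trace)   # [400, 336, 272, 208, 144, 80, 16]
assert total == sum(valid)
```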

Negative values are rarer. In our use case, we sometimes add empty loop iterations to avoid remainder loops when performing loop unrolling. In the code above, this would change the loop bound:

...
// The upper bound is changed from 400 to 512, adding an empty iteration:
%0 = affine.for %arg2 = 0 to 512 step 64 iter_args(%arg3 = %cst_0) -> (vector<64xf32>) {
  // But we still use the constant 400 here to make sure that the last iteration is really empty:
  %elts_left = affine.apply affine_map<(d0) -> (400 - d0)>(%arg2)
   ...

When %arg2 equals 448, the value of %elts_left becomes negative (400 - 448 = -48), but the intention is the same: to filter out elements with index >= 400. The mask must therefore be all zeros in this case.
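The padded loop can be checked in Python as well. This sketch uses a hypothetical helper `lane_mask` and assumes signed comparison, as proposed here. It verifies that the iteration at 448 gets an empty mask and that every index below 400 is selected exactly once.

```python
VF = 64
N = 400       # the original trip count, kept in elts_left = 400 - iv
UPPER = 512   # padded upper bound: one extra, empty iteration


def lane_mask(elts_left, vf=VF):
    # Signed comparison: a negative elts_left yields an all-false mask.
    return [lane < elts_left for lane in range(vf)]


masks = {iv: lane_mask(N - iv) for iv in range(0, UPPER, VF)}
last_iv = max(masks)
print(last_iv, N - last_iv, any(masks[last_iv]))   # 448 -48 False

# Every global index < 400 is selected exactly once and nothing else is.
selected = [iv + lane for iv, m in masks.items()
            for lane, on in enumerate(m) if on]
assert selected == list(range(N))
```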

In my opinion, both cases are natural continuations of create_mask's semantics. We might want to keep the checks on the constant_mask op, though, to have better protection against mistakes.

Thanks, Sergei! Much clearer now, and your proposal makes more sense to me. The problem is very specific to the way you are generating code, but it's clearly exposing some corner cases of the create_mask operation. I'll leave it to Aart, though, since he has much more context on these ops (I'm not even sure why we have both a constant and a non-constant variant of this op, so I can't fully assess the implications of this change).

My suggestion here, though, is that you shouldn't be limited by the create_mask op. You could always create masks by using logical operations. For example, you could compute something like (assuming VF=4): ([iv, iv, iv, iv] + [0, 1, 2, 3]) < [ub, ub, ub, ub]. I'm currently working on a proposal for masking and I'm not always able to use create_mask to generate all the masks needed.
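The comparison-based mask described above can be written lane by lane. This is a Python sketch with hypothetical helpers, using VF=4 as in the comment. It checks that the result matches a create_mask of `ub - iv` under the signed, clamping semantics proposed in this revision, assuming exact (non-overflowing) integer arithmetic.

```python
VF = 4


def create_mask_1d(n, vf=VF):
    # create_mask with the signed, clamping semantics of this revision.
    return [i < n for i in range(vf)]


def compare_mask(iv, ub, vf=VF):
    # ([iv, iv, iv, iv] + [0, 1, 2, 3]) < [ub, ub, ub, ub], lane by lane.
    lanes = [iv + step for step in range(vf)]
    return [lane < ub for lane in lanes]


print(compare_mask(8, 10))   # [True, True, False, False]

# iv + step < ub  <=>  step < ub - iv, so the two forms agree.
for iv in range(0, 24, VF):
    for ub in range(-4, 30):
        assert compare_mask(iv, ub) == create_mask_1d(ub - iv)
```

In this loop, `ub - iv` is sometimes negative and sometimes larger than VF, which is exactly the range this revision allows.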

Hopefully that helps!

Thanks,
Diego

The issue here is that the index type does not have a sign, so all mask values are interpreted as >= 0. In the original implementation, the

%0 = arith.constant -2 : index

would really give a mask value of

18446744073709551614

This change gives the mask index an explicit (signed) interpretation.
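The bit-level reinterpretation can be reproduced in Python. This sketch assumes `index` is lowered to a 64-bit integer (the width is an assumption), and `as_unsigned` and `mask` are hypothetical helpers. Reading the same bits as unsigned turns -2 into a huge size and therefore an all-ones mask, while the signed reading gives an empty mask.

```python
INDEX_BITS = 64   # assumption: index lowered to a 64-bit integer


def as_unsigned(v, bits=INDEX_BITS):
    # Reinterpret the two's-complement bit pattern of v as unsigned.
    return v % (1 << bits)


def mask(n, dim):
    # Lane i is set iff i < n.
    return [1 if i < n else 0 for i in range(dim)]


raw = -2
print(as_unsigned(raw))             # 18446744073709551614
print(mask(raw, 10))                # signed reading: all zeros
print(mask(as_unsigned(raw), 10))   # unsigned reading: all ones
```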

aartbik accepted this revision. Jan 18 2022, 5:35 PM

I looked at some of the actual implementations, and calling

func @create_vector_mask_dyn(%c : index) -> vector<10xi1> {
  %0 = vector.create_mask %c : vector<10xi1>
  return %0 : vector<10xi1>
}

with %c between -4 and 11 gives:

-4 == 18446744073709551612 ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-3 == 18446744073709551613 ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-2 == 18446744073709551614 ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-1 == 18446744073709551615 ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
0 ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
1 ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
2 ( 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
3 ( 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 )
4 ( 1, 1, 1, 1, 0, 0, 0, 0, 0, 0 )
5 ( 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 )
6 ( 1, 1, 1, 1, 1, 1, 0, 0, 0, 0 )
7 ( 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 )
8 ( 1, 1, 1, 1, 1, 1, 1, 1, 0, 0 )
9 ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 0 )
10 ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
11 ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )

So this change seems in line with the de facto implementation (I see signed comparisons in the generated code).
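The table can be regenerated with a small Python model. This is a sketch: `mask_signed` and `mask_unsigned` are hypothetical helpers, and a 64-bit `index` is assumed. Only the signed comparison reproduces the rows for -4 through -1. An unsigned comparison would turn those rows into all-ones masks.

```python
DIM = 10
WORD = 1 << 64   # assumption: index lowered to 64 bits


def mask_signed(c, dim=DIM):
    # Lane i is set iff i < c, compared as signed integers.
    return tuple(1 if i < c else 0 for i in range(dim))


def mask_unsigned(c, dim=DIM):
    # Same comparison after reinterpreting c's bits as unsigned.
    return tuple(1 if i < c % WORD else 0 for i in range(dim))


for c in range(-4, 12):
    label = f"{c} == {c % WORD}" if c < 0 else str(c)
    print(label, mask_signed(c))

# Negative inputs: signed gives all zeros (as observed), unsigned all ones.
assert mask_signed(-1) == (0,) * DIM
assert mask_unsigned(-1) == (1,) * DIM
```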

This revision is now accepted and ready to land. Jan 18 2022, 5:35 PM

Thanks! I'll merge this tomorrow if there are no more comments.
(As an alternative, we could merge only the half of this change related to values larger than the vector size and keep the old behavior for negative values. It's probably better, though, to explicitly require that the mask index be interpreted as a signed integer than to leave this unspecified.)