
[MLIR] Introduce the type inference pass.
Needs Review · Public

Authored by arames on Sep 29 2021, 1:17 PM.

Details

Summary

Type Inference uses the existing forward dataflow analysis mechanisms, with a
lattice over types and its join operation.

It is designed to be safe to run on existing IR: it will not modify types
unless instructed to.
The AllowsInputTypesRefinementInterface and
AllowsOutputTypesRefinementInterface interfaces allow fine-grained control
over which input/output types can be specialized. The type inference pass
will insert type_specialization and type_relaxation ops to satisfy the
specified constraints.

Notes:

  • Due to potentially arbitrary constraints on block argument types coming from parent blocks' terminators, block argument types are never specialized. Instead, type_specialization ops are inserted when appropriate.
  • Support for input types refinement is fully controlled by AllowsInputTypesRefinementInterface.
  • Support for output types refinement is explicit if AllowsOutputTypesRefinementInterface is implemented. If not, it is implicit if InferTypeOpInterface is implemented.
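To make the lattice concrete, here is a small sketch in Python pseudocode. The shape model and names are illustrative assumptions, not MLIR's actual C++ API: a ranked tensor shape is a tuple whose `None` entries stand for a dynamic dimension (`?`), and `UNRANKED` carries no shape information.

```python
# Illustrative model of the type lattice join (not MLIR's real API).
UNINITIALIZED = object()  # lattice bottom: no information seen yet
UNRANKED = "unranked"     # least specialized shape: rank unknown

def join(a, b):
    """Least upper bound: the most specialized shape that is less
    specialized than both a and b."""
    if a is UNINITIALIZED:
        return b
    if b is UNINITIALIZED:
        return a
    if a == UNRANKED or b == UNRANKED or len(a) != len(b):
        return UNRANKED  # incompatible ranks can only join to unranked
    # Matching dims are kept; mismatching dims become dynamic ('?').
    return tuple(x if x == y else None for x, y in zip(a, b))
```

For example, joining the shapes of tensor<1x2xi32> and tensor<1x9xi32> yields tensor<1x?xi32>, the control-flow merge behavior discussed later in this review.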

Diff Detail

Event Timeline

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
2107 ↗(On Diff #376020)

I'd prefer an active void here. "relax_types"

2144–2145 ↗(On Diff #376020)

specialize_types?

2152–2153 ↗(On Diff #376020)

This seems a little awkward to me. 'join_types' and 'type_specialization' seem to work together to implement a 'cast' with a dummy %join argument to pass the type information? is it useful to have them separate? (perhaps to force multiple casts to be performed the same way?)

mlir/include/mlir/Interfaces/InferTypeOpInterface.td
34–40

Should probably be documented that this needs to be a monotonic function on the type lattice?

181–182

"specialization" is used elsewhere. It would be good to consistently use one term or the other.

181–218

These interfaces seem to be able to take into account the current type of the input or output. However, they probably should be monotonic functions representing type constraints on the type lattice. Would it be better to have a less stateful interface that passed the current type and was required to be a pure function?
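For reference, a transfer function f is monotonic on the type lattice when x ⊑ y (x at least as specialized as y) implies f(x) ⊑ f(y). A small property check, under an assumed toy shape model (tuples with `None` for `?`, `"unranked"` as top — not MLIR's real API), might look like:

```python
# Hypothetical monotonicity check for a shape transfer function.
def refines(x, y):
    """True if x is at least as specialized as y (x <= y in the lattice)."""
    if y == "unranked":
        return True
    if x == "unranked" or len(x) != len(y):
        return False
    # Every known dim of y must be matched by x; dynamic dims accept anything.
    return all(b is None or a == b for a, b in zip(x, y))

def is_monotonic(f, samples):
    """Check that f preserves the refinement order over the sample shapes."""
    return all(not refines(x, y) or refines(f(x), f(y))
               for x in samples for y in samples)
```

A function that maps a dynamic shape to an unrelated static one would fail this check, which is the kind of behavior the monotonicity requirement rules out.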

mlir/include/mlir/Transforms/Passes.td
706 ↗(On Diff #376020)

Output?

707 ↗(On Diff #376020)

What is implicit? That input and output types can both be refined?

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
2820 ↗(On Diff #376020)

isLessSpecializedOrSame should probably be allowed and then canonicalized out?

mlir/lib/Transforms/TypeInference.cpp
2

InferTypes.cpp?

15

Can you add a picture of the type lattice?

107

Why is this a special case and not just another op that implements the right interfaces?

108

specialize

166

Perhaps this should just imply a default implementation of inferTypeOpInterface, to avoid a special case here?

182

Perhaps this should just imply a default implementation of inferTypeOpInterface, to avoid a special case here?

mlir/test/Transforms/type-inference.mlir
33–34

This seems weird because the type information is being propagated through join_types without its output type changing?

47–50

This also seems weird. why doesn't join_types output the join of its input types?

rriddle requested changes to this revision. Sep 29 2021, 2:48 PM

Will try to take a look in the next few days.

mlir/lib/Transforms/TypeInference.cpp
20

I'd suggest avoiding any kind of dialect dependency, because it prevents this from being a "general" transformation. I'd drive the creation of cast operations through a dialect interface specific to this transformation, and allow for dialects to decide how to do things.

This revision now requires changes to proceed. Sep 29 2021, 2:48 PM

Have you tried this on any graph regions? I guess maybe the same question goes for ForwardDataflowAnalysis?

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
2107 ↗(On Diff #376020)

I'd prefer an active void here. "relax_types"

active voice :)

mlir/lib/Transforms/TypeInference.cpp
10–15

It seems that there is a critical assumption here: that static type information represents an abstract interpretation of the behavior of an op. Meaning that, given the universe of values, a particular op will always map an input in the set to an output in the set, regardless of the type. In this case, specializing a type only represents the presence of more information about the op (i.e. a smaller set of possible values) and cannot change the behavior of the op. Certain kinds of ops are dangerous/would be affected by this transformation, like an op that outputs an integer given the number of input dimensions with unknown size.

arames marked 10 inline comments as done. Sep 30 2021, 12:52 PM

I went through the first round of comments.
I addressed the minor ones, and am working on the other open:

  • Consider a different mechanism than specialize_type and relax_type ops.
  • No special case for relax_type (if it is kept).
  • Maybe SameOperandsAndResultType (and similar traits) imply a default implementation of InferTypeOpInterface.
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
2152–2153 ↗(On Diff #376020)

I tried to use a short example taken from the tests, but it must be confusing and thus not useful.
I reworked both the "specialize" and "relax" examples.

mlir/include/mlir/Interfaces/InferTypeOpInterface.td
181–182

I updated the patch to use "specialize" and "relax" everywhere.

181–218

I was only aware that some exotic ops may need some control over this, but not to what extent. So I set it up this way to allow maximum control on the op side.
I am not sure looking at the current type would be enough here. For example, I can imagine the decision depending on the input/output index, or on operation attributes.

I think @silvas may want to weigh on this. I personally don't have a use for this, so would be fine restricting it.

mlir/include/mlir/Transforms/Passes.td
707 ↗(On Diff #376020)

"support for output types refinement" is implicit.

I rewrote to clarify.

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
2820 ↗(On Diff #376020)

Sounds fair.
It should not happen with the current usage, but I can see it being useful in the future. (For example if we somehow support updating block argument types separately.)

mlir/lib/Transforms/TypeInference.cpp
2

Looking at other filenames in that directory, there seem to be a mix of active and passive voice names, with a majority of passive.
For example:

  • BufferDeallocation.cpp
  • Canonicalizer.cpp
  • LoopInvariantCodeMotion.cpp

But:

  • Bufferize.cpp
  • NormalizeMemRefs.cpp

Has a preference been established? I'm happy to update to InferTypes.cpp and InferTypesPass if that's the established preference.

10–15

I believe the introduced AllowsInputTypesSpecializationInterface and AllowsOutputTypesSpecializationInterface interfaces specifically address this.
By default, ops do NOT allow input or output type specialization (see details in the description of AllowsOutputTypesSpecializationInterface).

Could you check whether that addresses your concerns?

20

Sounds interesting. I'll look into it.

107

I missed this as I was thinking of it as special, but it probably can simply use InferReturnTypes. Looking into it.

166

I'll look into it.
I suppose it would require the presence of the trait to imply implementation of the interface.
It should be relatively easy to do via tblgen, but is that something we already do, or are fine doing?

mlir/test/Transforms/type-inference.mlir
33–34

Yes. The behavior is intentional.

This follows the mechanisms to control type specialization for input types, which might be required for exotic ops.

For example I think the relax_type op itself (if kept) should be updated to use this instead of the special case:

  • infer the result type to be the same as the input type
  • but never specialize the output type

I expect most ops will simply rely on implicit support for output types specialization via InferTypeOpInterface, or maybe conditional via AllowsOutputTypesSpecializationInterface for specific outputs.

47–50

test.ti.join_types is a test op that "processes" its inputs, and somehow returns a result with a type that is the join of the input types.
Here we simply check that type-inference correctly propagates type information through type_relaxation (relax_type).
I agree the test as-is may be confusing. I reworked it to be clearer.

arames updated this revision to Diff 376332. Sep 30 2021, 12:53 PM
arames marked 3 inline comments as done.

Address minor comments. Still some work to do.

silvas added inline comments. Sep 30 2021, 12:58 PM
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
181–218

These interfaces only determine whether an output or input is allowed to be refined. They don't actually mutate anything or calculate a type to refine to, so I don't think monotonicity concerns intersect with these two specific interfaces.

mlir/lib/Transforms/TypeInference.cpp
36

In torch-mlir, we found that DataFlowAnalysis.h miscompiles because it does not "meet" at control flow merge points: https://bugs.llvm.org/show_bug.cgi?id=51636

I suspect that this patch will be affected as well -- I don't think the analysis we want here can be implemented correctly without a "meet" callback on the lattice value as well.

CC @rriddle too

silvas added inline comments. Sep 30 2021, 1:02 PM
mlir/lib/Transforms/TypeInference.cpp
2

For me personally (take it with a grain of salt), I have been recently always using active voice (inferTypes.cpp) because it reads nicer on the mlir-opt command line (like a sequence of actions). I don't mind either way though.

arames added inline comments. Sep 30 2021, 1:26 PM
mlir/lib/Transforms/TypeInference.cpp
2

I was already using the term infer-types for the mlir-opt command.
I'll rename the file, pass, etc. to match this in the next update of the diff.

arames added inline comments. Sep 30 2021, 2:27 PM
mlir/lib/Transforms/TypeInference.cpp
166

I see this is already being done in tblgen. (For example code added in 31f40f603d0c0.).
So I'll extend that mechanism in a separate patch.

arames added inline comments. Oct 6 2021, 11:05 AM
mlir/lib/Transforms/TypeInference.cpp
36

Looking at the bug report, I am not understanding what is stated.
For type inference, it seems to me the dataflow analysis should always join. And the join of 3x? and ?x4 should be ?x?, not 3x4.
Refinement is possible because value types start as uninitialized, not as their original type.
I test this for example in test_regionbranchopinterface in this patch.
I'll look at the bug. But I'd also be interested in chatting more.

arames updated this revision to Diff 377613. Oct 6 2021, 11:14 AM

Address first round of comments.

arames added inline comments. Oct 6 2021, 11:59 AM
mlir/lib/Transforms/TypeInference.cpp
2

Looking at this again, it seems to me filenames and classes use nouns more often than "active voice" names.
I do use -infer-types for the pass argument though.
I'm open to change if there is more support for the other approach.

36

Here is an example. Imagine test.ti.join_types are data-processing ops that happen to have the same return type as the input.

func @test_branch_join(%cond: i1, %x : tensor<1x2xi32>, %y : tensor<1x9xi32>) {
  cond_br %cond, ^bb1, ^bb2
^bb1:
  %val_true = "test.ti.join_types"(%x) : (tensor<1x2xi32>) -> tensor<*xi32>
  br ^bb3(%val_true : tensor<*xi32>)
^bb2:
  %val_false = "test.ti.join_types"(%y) : (tensor<1x9xi32>) -> tensor<*xi32>
  br ^bb3(%val_false : tensor<*xi32>)
^bb3(%val: tensor<*xi32>):
  "test.ti.allow_input_type_specialization"(%val) : (tensor<*xi32>) -> ()
  "test.ti.disallow_input_type_specialization"(%val) : (tensor<*xi32>) -> ()
  return
}
% /bin/mlir-opt -infer-types ~/work/tmp/test.mlir
module  {
  func @test_branch_join(%arg0: i1, %arg1: tensor<1x2xi32>, %arg2: tensor<1x9xi32>) {
    cond_br %arg0, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %0 = "test.ti.join_types"(%arg1) : (tensor<1x2xi32>) -> tensor<1x2xi32>
    %1 = relax_type %0 : tensor<1x2xi32> to tensor<*xi32>
    br ^bb3(%1 : tensor<*xi32>)
  ^bb2:  // pred: ^bb0
    %2 = "test.ti.join_types"(%arg2) : (tensor<1x9xi32>) -> tensor<1x9xi32>
    %3 = relax_type %2 : tensor<1x9xi32> to tensor<*xi32>
    br ^bb3(%3 : tensor<*xi32>)
  ^bb3(%4: tensor<*xi32>):  // 2 preds: ^bb1, ^bb2
    %5 = specialize_type %4 : tensor<*xi32> to tensor<1x?xi32>
    "test.ti.allow_input_type_specialization"(%5) : (tensor<1x?xi32>) -> ()
    "test.ti.disallow_input_type_specialization"(%4) : (tensor<*xi32>) -> ()
    return
  }
}

We correctly see the join of tensor<1x2xi32> and tensor<1x9xi32> to tensor<1x?xi32>.
The operation required is a join, not a meet.

silvas added inline comments. Oct 6 2021, 12:13 PM
mlir/lib/Transforms/TypeInference.cpp
36

Yes, for control flow confluence what you want is to join (which can only derefine).

What I'm confused about is how your updateLatticeElement function end up refining. Using your example, I'd like to understand better how %val_true is treated.

Eventually it runs:

// latticeElement has initial value of tensor<*xi32>
latticeElement = getLatticeElement(%val_true);
// then you run:
latticeElement.join(tensor<1x2xi32>)
// Now latticeElement should still be tensor<*xi32> per join semantics?
// How does latticeElement get refined to the final value of tensor<1x2xi32>?
arames added inline comments. Oct 6 2021, 1:34 PM
mlir/lib/Transforms/TypeInference.cpp
36

I think I see what you missed.

// latticeElement has initial value of tensor<*xi32>
latticeElement = getLatticeElement(%val_true);

No, it does not!
The current type of a value is only used (in the analysis) for ops that do not support InferTypeOpInterface. This is the point that allows refining types.

During the analysis:
All lattice elements start in the bottom state.
Then, there are two cases:

  • an op does not support InferTypeOpInterface. In that case, we do join with the current type: join(uninitialized, current_type) -> current_type
  • an op does support InferTypeOpInterface. Then we do not use the current type. Instead, we join with the result of the inferReturnTypes function.

When using the results of the analysis, we verify that each inferred type is more specialized than the current type.
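The two cases above can be sketched as follows. This is a hedged Python model with hypothetical names and a toy shape model (tuples with `None` for `?`), not the actual ForwardDataflowAnalysis API:

```python
UNINITIALIZED = None  # lattice bottom; every lattice element starts here

def join(a, b):
    # Toy join: tuples of dims, None for '?', "unranked" on rank mismatch.
    if a is UNINITIALIZED:
        return b
    if b is UNINITIALIZED:
        return a
    if a == "unranked" or b == "unranked" or len(a) != len(b):
        return "unranked"
    return tuple(x if x == y else None for x, y in zip(a, b))

def visit_result(lattice_value, infer_fn, current_type, operand_types):
    """Transfer function for one op result. infer_fn stands in for the op's
    InferTypeOpInterface implementation, or None if the op lacks it."""
    if infer_fn is None:
        # Case 1: no InferTypeOpInterface -- join with the type already
        # written in the IR, so the analysis never invents a new type.
        return join(lattice_value, current_type)
    # Case 2: InferTypeOpInterface -- join with the inferred type and
    # ignore the current type; this is what allows refinement.
    return join(lattice_value, infer_fn(operand_types))
```

When the analysis results are applied, each inferred type would then be checked to be at least as specialized as the type currently in the IR, as described above.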

arames added inline comments. Oct 6 2021, 1:57 PM
mlir/lib/Transforms/TypeInference.cpp
36

To go in the implementation details, after

// latticeElement has initial value of tensor<*xi32>
latticeElement = getLatticeElement(%val_true);

,
latticeElement.isUninitialized() is true. So in

latticeElement.join(value with type tensor<1x2xi32>)

LatticeElement::join will select the tensor type for optimisticValue.

I'll read the code again.
Here is a dump with the debug info. Updates to lattices elements for the control-flow are not visible because they happen in the dataflow analysis code.

% ./bin/mlir-opt -debug-only=infer-types -infer-types ~/work/tmp/test.mlir
update lattice element for :    %1 = "test.ti.join_types"(%arg2) : (tensor<1x9xi32>) -> tensor<*xi32>
        original type:  uninitialized
        joining with:   tensor<1x9xi32>
        yielding:       tensor<1x9xi32>
update lattice element for :    %0 = "test.ti.join_types"(%arg1) : (tensor<1x2xi32>) -> tensor<*xi32>
        original type:  uninitialized
        joining with:   tensor<1x2xi32>
        yielding:       tensor<1x2xi32>
updated type from tensor<*xi32> to tensor<1x2xi32> (with 0/1 specialized uses) for %0 = "test.ti.join_types"(%arg1) : (tensor<1x2xi32>) -> tensor<1x2xi32>
updated type from tensor<*xi32> to tensor<1x9xi32> (with 0/1 specialized uses) for %2 = "test.ti.join_types"(%arg2) : (tensor<1x9xi32>) -> tensor<1x9xi32>
updated type from tensor<*xi32> to tensor<1x?xi32> (with 1/1 specialized uses) for <block argument> of type 'tensor<*xi32>' at index: 0
module  {
  func @test_branch_join(%arg0: i1, %arg1: tensor<1x2xi32>, %arg2: tensor<1x9xi32>) {
    cond_br %arg0, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %0 = "test.ti.join_types"(%arg1) : (tensor<1x2xi32>) -> tensor<1x2xi32>
    %1 = relax_type %0 : tensor<1x2xi32> to tensor<*xi32>
    br ^bb3(%1 : tensor<*xi32>)
  ^bb2:  // pred: ^bb0
    %2 = "test.ti.join_types"(%arg2) : (tensor<1x9xi32>) -> tensor<1x9xi32>
    %3 = relax_type %2 : tensor<1x9xi32> to tensor<*xi32>
    br ^bb3(%3 : tensor<*xi32>)
  ^bb3(%4: tensor<*xi32>):  // 2 preds: ^bb1, ^bb2
    %5 = specialize_type %4 : tensor<*xi32> to tensor<1x?xi32>
    "test.ti.allow_input_type_specialization"(%5) : (tensor<1x?xi32>) -> ()
    return
  }
}
arames updated this revision to Diff 377681. Oct 6 2021, 2:04 PM

Include changes to remove explicit support for Same* traits.
They will be handled in a separate patch by providing a default implementation of InferTypeOpInterface.

arames updated this revision to Diff 377682. Oct 6 2021, 2:08 PM
arames marked 2 inline comments as done.

NFC refactor.

arames added inline comments. Oct 6 2021, 2:09 PM
mlir/lib/Transforms/TypeInference.cpp
166

I removed the code to handle the different interfaces. It now only supports InferTypeOpInterface; ops that do not implement it are not supported.
I'll add support for other interfaces in a separate patch, as you suggested, by supplying a default implementation of InferTypeOpInterface.

silvas added inline comments. Oct 6 2021, 3:27 PM
mlir/lib/Transforms/TypeInference.cpp
36

To go in the implementation details, after

// latticeElement has initial value of tensor<*xi32>
latticeElement = getLatticeElement(%val_true);
,
latticeElement.isUninitialized() is true. So in

What if the fixed-point is not reached in the first iteration? What if it takes multiple iterations to reach fixed-point? Won't isUninitialized return false and no further refinement will happen after the first iteration, so the true fixed-point won't be reached?

Or to put it another way, I view this as a pessimistic dataflow problem where you start with the current result type, and then work your way down the lattice to a more refined type.

From what you describe, the code currently starts at the bottom of the lattice, and moves up, which means that the first inferred type (which might not be the most refined possible) gets "stuck".

Oh, actually I may have answered my own question. The SCCP-like framework uses optimistic values for backedges, so any later visitations can only make the type less refined. I see.

Carry on.... sorry for the noise.

arames added inline comments. Oct 6 2021, 3:45 PM
mlir/lib/Transforms/TypeInference.cpp
36

From what you describe, the code currently starts at the bottom of the lattice, and moves up, which means that the first inferred type (which might not be the most refined possible) gets "stuck".

True.
I need to think more about cases that could hit this.

Or to put it another way, I view this as a pessimistic dataflow problem where you start with the current result type, and then work your way down the lattice to a more refined type.

Not sure how much of this is obvious already: I don't think you can walk down the lattice to a more refined type. Walking down the lattice means using the meet operation.
Let's imagine that, for a given value, we will see types t1, t2, and t3.
The inferred type we need is the most specialized type that is less specialized than any of them. It is by definition join(t1, t2, t3).
We cannot refine the initial result type, because any meet operation could be incompatible with other observed types.
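This argument can be illustrated directly under a toy shape model (tuples with `None` for `?`; names are illustrative, not MLIR's API): the inferred type is the fold of join over every observed type, whereas a meet of incompatible observed types would have no valid result.

```python
from functools import reduce

def join2(a, b):
    # Toy join of two shapes (tuples, None = '?').
    if len(a) != len(b):
        return "unranked"
    return tuple(x if x == y else None for x, y in zip(a, b))

def join(*types):
    """join(t1, t2, t3): the most specialized type that is less
    specialized than every observed type."""
    return reduce(join2, types)
```

For instance, a value observed with shapes 2x2, 2x9, and 2x4 is inferred as 2x?; a meet would instead have to find a shape refining all three, which does not exist here.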

silvas added inline comments. Oct 6 2021, 4:50 PM
mlir/lib/Transforms/TypeInference.cpp
36

Not sure how much of this is obvious already: I don't think you can walk down the lattice to a more refined type. Walking down the lattice means using the meet operation.
Let's imagine that, for a given value, we will see types t1, t2, and t3.
The inferred type we need is the most specialized type that is less specialized than any of them. It is by definition join(t1, t2, t3).
We cannot refine the initial result type, because any meet operation could be incompatible with other observed types.

Not sure I follow. It is totally possible for a result type to be statically written as tensor<*xf32> and then be refined to tensor<?x?xf32> and then be refined to tensor<2x2xf32> as we analyze the program further. These could all come from the same transfer function, but as the operand lattice values are refined, the corresponding result lattice values will get refined. This does not happen in the optimistic formulation we are using here though (since it only moves up the lattice, and uses its special handling of backedges optimistically to make it all work out).

True.
I need to think more about cases that could hit this.

Because of the optimistic backedge handling, the "stuck" value is guaranteed to be the most refined with the current framework. (assuming visitOperation only updates lattice values for results, and not for other unrelated values in the program)

arames updated this revision to Diff 382784. Oct 27 2021, 1:58 PM

Rebase on top of tree.
Slightly factor out the creation of type specialization and relaxation ops in TI.

arames added inline comments. Oct 28 2021, 10:11 AM
mlir/lib/Transforms/TypeInference.cpp
176

Working on this right now.

arames updated this revision to Diff 383421. Oct 29 2021, 10:33 AM

Introduce mechanism to customize explicit type specialization/relaxation.

(nice, just a couple of fly-by comments)

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
1094 ↗(On Diff #383421)

+1 (on my wishlist, I told Alex I'd get to this too as it makes the Python side gen needlessly complicated)

1103 ↗(On Diff #383421)

Nit: Let's start with what the op does and then where introduced.

mlir/include/mlir/Interfaces/InferTypeOpInterface.td
192

And this is on the "flattened" form? E.g., are variadics as in ODS considered?

arames updated this revision to Diff 383435. Oct 29 2021, 11:01 AM
arames marked an inline comment as done.

Address nit for op descriptions.

arames updated this revision to Diff 383869. Nov 1 2021, 1:27 PM

Reworked TI to be a utility class instead of a pass, removing dependency on standard ops.

arames updated this revision to Diff 383871. Nov 1 2021, 1:34 PM

Upload the right patch.