This is an archive of the discontinued LLVM Phabricator instance.

[MLIR][LINALG] Add canonicalization pattern in `linalg.generic` op for static shape inference.
ClosedPublic

Authored by gprateek93 on Feb 3 2022, 11:19 AM.

Details

Summary

This commit adds a canonicalization pattern to the linalg.generic op
for static shape inference. If any of the inputs or outputs has a
static shape, or is cast from a tensor of static shape, then the
shapes of all inputs and outputs can be inferred by using the
affine map of that static-shaped input/output.
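To make the inference concrete, here is a standalone C++ sketch of just the algorithm (not the MLIR implementation; the names and the kDynamic encoding are invented for this illustration). With projected-permutation maps, a static size on any operand dimension pins the loop dimension it indexes, and that size then propagates to every operand that uses the same loop dimension:

#include <cstddef>
#include <cstdio>
#include <map>
#include <vector>

constexpr long kDynamic = -1; // stand-in encoding for a dynamic dimension

int main() {
  // Two operands over loop dims (d0, d1); each map lists, per operand
  // dimension, the loop dim it indexes: {0, 1} means (d0, d1).
  std::vector<std::vector<int>> maps = {{0, 1}, {1, 0}};
  std::vector<std::vector<long>> shapes = {{4, kDynamic}, {8, kDynamic}};

  // Pass 1: harvest static sizes per loop dimension.
  std::map<int, long> dimToSize;
  for (std::size_t op = 0; op < maps.size(); ++op)
    for (std::size_t i = 0; i < maps[op].size(); ++i)
      if (shapes[op][i] != kDynamic)
        dimToSize[maps[op][i]] = shapes[op][i];

  // Pass 2: propagate the harvested sizes into every operand's shape.
  for (std::size_t op = 0; op < maps.size(); ++op)
    for (std::size_t i = 0; i < maps[op].size(); ++i) {
      auto it = dimToSize.find(maps[op][i]);
      if (it != dimToSize.end())
        shapes[op][i] = it->second;
    }

  // Prints tensor<4x8> and tensor<8x4>: both shapes became fully static.
  for (const auto &s : shapes)
    std::printf("tensor<%ldx%ld>\n", s[0], s[1]);
  return 0;
}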

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>

Diff Detail

Event Timeline

gprateek93 created this revision. Feb 3 2022, 11:19 AM
gprateek93 requested review of this revision. Feb 3 2022, 11:19 AM
mravishankar requested changes to this revision. Feb 3 2022, 8:06 PM

Thanks for helping push on this. I do think we can simplify the implementation here by a lot.

@nicolasvasilache might have better ways of doing this for arbitrary affine exprs in indexing maps. In general you need to solve a set of linear equations to compute the static dimensions. There might be a way to formulate that in affine-map "arithmetic", but that seems very heavyweight (and I haven't worked out how to express it using affine map computations). I am thinking of solving just the common case where the affine maps are all projected permutations.
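For instance, a convolution-style map result like d0 + d1 with iteration bounds N0, N1 and a known static operand size s yields only the single equation N0 + N1 - 1 = s, which does not determine either bound on its own; a projected-permutation result is a lone dim expr, so a static size there pins the corresponding loop dimension directly.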

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
861

General note: I prefer not to use lambdas this way. It's much more readable to have a static method outside of the function.

865

Do you need to pass a copy of SmallVector<OpOperand *> here? Can you use an ArrayRef?

867

If you are going to use a lambda, I would put the loop outside the lambda and simplify its signature.

873

I'd refactor this a bit (a rough sketch of both steps follows the list). First get the shape value:

  • If the operand is the result of a tensor.cast, get the shape from the source of the tensor.cast; if not, use the shape of the operand.
  • Then iterate over the shape obtained in step 1 and check whether each dimension is static or dynamic. If static, add it to sourceMap.
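A rough sketch of both steps, assuming the linalg/tensor APIs of the time (getTiedIndexingMap, tensor::CastOp's source accessor); exact accessor names may differ:

// Step 1: take the shape from the tensor.cast source when there is one.
Value src = opOperand->get();
if (auto castOp = src.getDefiningOp<tensor::CastOp>())
  src = castOp.source();
ArrayRef<int64_t> shape = src.getType().cast<RankedTensorType>().getShape();

// Step 2: record every static size under its AffineDimExpr key.
AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
for (unsigned i = 0, e = sourceMap.getNumResults(); i != e; ++i) {
  auto dimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>();
  if (!dimExpr)
    continue; // Not a plain dim expr; nothing to record.
  if (!ShapedType::isDynamic(shape[i]))
    affineExprToSize[dimExpr] = shape[i];
}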
878

It is not necessary that the result be an AffineDimExpr. You need to check that it is, and bail if it is not.

888

I think the else part can be dropped after the above suggestion.

915

Same comment here as above. Move the loop out of the lambda.

This revision now requires changes to proceed. Feb 3 2022, 8:06 PM
gprateek93 updated this revision to Diff 405934. Feb 4 2022, 5:35 AM

Refactored the code.

gprateek93 marked 7 inline comments as done. Feb 4 2022, 5:37 AM
gprateek93 updated this revision to Diff 405974. Feb 4 2022, 8:07 AM

Added check for the operands to be RankedTensorType.

mravishankar requested changes to this revision. Feb 7 2022, 10:59 AM

Thanks for the updates. I have one more round of comments. I think after that this should be ready to land.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
853

The startIdx part is confusing to me in the way this is written. There seems to be some handshake between source and operands here. Maybe this could work:

static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
                        ValueRange source,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (auto it : llvm::zip(operands, source)) {
    OpOperand *opOperand = std::get<0>(it);
    Value src = std::get<1>(it);
    ...
  }
  ...
}

In any case, looking at the uses, do you even need source? It looks like operands is all you need.

856

Could you add an assert here to check sourceMap.isProjectedPermutation()? I am assuming the caller verifies that the map is a projected permutation and doesn't call this function if that is not the case.
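For example (message text is illustrative):

assert(sourceMap.isProjectedPermutation() &&
       "expected `sourceMap` to be a projected permutation");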

869

nit: LLVM style is to not have { ... } for 1-line scopes.

876

Please do not use -1. Use ShapedType::isDynamicDim instead.
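Both spellings of the check exist on ShapedType (sourceType and shape here are stand-in names):

// Query the ranked type itself:
if (sourceType.isDynamicDim(i))
  continue;
// Or query a raw size taken from getShape():
if (ShapedType::isDynamic(shape[i]))
  continue;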

893

Same as above, you probably don't need source and startIdx. Also please move affineExprToSize "before" the parameter that is returned by reference (resultTypeVector, I guess). I think there is a simpler way to do this. I'll leave some comments below where this function is called.

915

Please don't use -1 explicitly. Use ShapedType::isDynamicDim instead.

956

You can replace lines 1007–1024 with:

if (!genericOp.hasTensorSemantics())
  return failure();
957

You also need to check that all indexing maps of the genericOp are projected permutations. Something like:

if (llvm::any_of(genericOp.getIndexingMaps(),
                 [](AffineMap map) { return !map.isProjectedPermutation(); }))
  return failure();
981

I think both of these calls can be combined into a single call to populateMap.

991

I think this split of updateOperands is unnecessary and confusing.

For example, for the inputs, newResultTypes is unnecessary. We can keep the concerns separate and do it in two steps:

SmallVector<Value> newOperands;
SmallVector<Type> resultTypes;
bool noChangeNeeded = true;
newOperands.reserve(genericOp.getNumInputsAndOutputs());
resultTypes.reserve(genericOp.getNumOutputs());
for (OpOperand *operand : genericOp.getInputAndOutputOpOperands()) {
  // Check whether the operand's type changes and insert a cast if needed.
  noChangeNeeded = ... // set to false if a change is needed.
  newOperands.push_back(newOperand);
  if (genericOp.isOutputTensor(operand))
    resultTypes.push_back(newOperand.getType());
}
if (noChangeNeeded) return failure();
auto *newOp = ... // Create new generic op.
SmallVector<Value> replacements;
replacements.reserve(newOp->getNumResults());
for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
  if (std::get<0>(it).getType() != std::get<1>(it).getType()) {
    // Insert cast.
    replacements.push_back(castResult);
  } else {
    replacements.push_back(std::get<1>(it));
  }
}
rewriter.replaceOp(genericOp, replacements);
This revision now requires changes to proceed. Feb 7 2022, 10:59 AM

Refactored the code.

gprateek93 marked 10 inline comments as done. Feb 10 2022, 8:07 AM
mravishankar requested changes to this revision. Feb 10 2022, 12:52 PM

Sorry, one more use of -1 and also an if-then-else style nit. Otherwise looks good!

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
924

Nit: I think LLVM style is to do:

for (...) {
  if (...) {
     ...
     continue;
  }

  // previously else-if part.
  if (...) {
    ....
    continue;
  }

  // previously else part
  ...
}
934

Please don't use -1 directly. Use ShapedType::isDynamicSize or equivalent. In any case, I don't think you need this check here.

This revision now requires changes to proceed. Feb 10 2022, 12:52 PM

Refactored the code and updated one of the reshape-fusion test cases.

gprateek93 marked 2 inline comments as done. Feb 11 2022, 1:45 AM
mravishankar accepted this revision. Feb 12 2022, 8:10 PM
This revision is now accepted and ready to land. Feb 12 2022, 8:10 PM
bondhugula requested changes to this revision. Feb 12 2022, 8:37 PM
bondhugula added a subscriber: bondhugula.
bondhugula added inline comments.
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
909–962

This whole logic is too long (about 100 lines), has too much nesting, and looks messy with generic boolean variables like changeNeeded. A couple of suggestions:

  1. Document changeNeeded above.
  2. Refactor it (e.g. the part iterating over operands) into an appropriately named method.
934

Please use is_contained instead of comparing with end().

mlir/test/Dialect/Linalg/canonicalize.mlir
655

Nit: Add a comment on what these are testing -- the function names themselves aren't descriptive enough.

This revision now requires changes to proceed. Feb 12 2022, 8:37 PM

Refactored code.

gprateek93 marked 3 inline comments as done. Feb 16 2022, 11:09 PM
gprateek93 added inline comments.
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
934

I am not sure how is_contained can be used with map iterators. Can you please elaborate? Thanks!

gprateek93 marked an inline comment as done. Feb 16 2022, 11:09 PM
gprateek93 marked an inline comment as not done.
bondhugula added inline comments.
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
934

I see - find() is fine.
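(For reference: llvm::is_contained works over ranges of values, and iterating a DenseMap yields key-value pairs, so a keyed find() is the idiomatic membership test here. dimExpr below is a stand-in name:)

auto it = affineExprToSize.find(dimExpr);
if (it == affineExprToSize.end())
  continue; // key absent, nothing to propagate
int64_t staticSize = it->second;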

This revision is now accepted and ready to land. Feb 20 2022, 7:33 PM

Rebasing onto main.

This revision was landed with ongoing or failed builds. Feb 20 2022, 11:51 PM
This revision was automatically updated to reflect the committed changes.