This is an archive of the discontinued LLVM Phabricator instance.

[mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.generic ops.
Closed, Public

Authored by mravishankar on Feb 14 2020, 12:56 PM.

Details

Summary

Fixes a bug where using a zero-rank shaped type operand in a
linalg.generic op hit an unrelated assert. As a result, lowering
such an operation to loops was not supported either. Adds roundtrip
tests and a lowering-to-loops test for zero-rank shaped type operands,
along with the fixes needed to make the tests pass.

Event Timeline

mravishankar created this revision. Feb 14 2020, 12:56 PM

Removing errant ;

nicolasvasilache accepted this revision. Feb 14 2020, 1:07 PM

Corner cases are the worst :)
Thanks a lot Mahesh, very cool!

rriddle accepted this revision. Feb 14 2020, 1:19 PM
rriddle added inline comments.
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
242

nit: Wrap this in {} to match the else.

mlir/lib/IR/AffineMap.cpp
354

nit: Can you wrap the predicate in ()?

mlir/test/Dialect/Linalg/loops.mlir
392

nit: Add a newline here.

This revision is now accepted and ready to land. Feb 14 2020, 1:19 PM
mravishankar marked 3 inline comments as done.

Addressing comments and fixing failing test

Change the test to check for a scalar load

This revision was automatically updated to reflect the committed changes.

Just a few formatting nits while going through the examples.

mlir/test/Dialect/Linalg/loops.mlir
365

Add a space after comma

376

Remove the space right after %arg0.
And there are two spaces after "%arg1:", remove one.

379

s/%b :/ %b:

mlir/test/Dialect/Linalg/roundtrip.mlir
350

Add a space after comma.

361

Remove the space between "%arg0" and ":".

361

There are two spaces after "->", remove one.

hanchung added inline comments. Feb 18 2020, 6:17 PM
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
346–351

This part should also be updated to handle zero-rank inputs:

  Value input = indexedGenericOp.getInput(i);
  if (!input.getType().cast<ShapedType>().getRank()) {
    indexedValues[nLoops + i] = std_load(input);
  } else {
    ValueHandleArray indexing(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
    indexedValues[nLoops + i] = std_load(input, indexing);
  }
}
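
For readers without an MLIR checkout, the rank check in the suggestion above can be illustrated with a small standalone model. This is only a sketch and not the MLIR implementation: ShapedValue, load, and loadOperand are hypothetical names, the data is assumed to be stored row-major, and the indexing map is assumed to be the identity over the leading loop induction variables.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a shaped value: a shape plus row-major data.
// A zero-rank value has an empty shape and holds exactly one element.
struct ShapedValue {
  std::vector<std::size_t> shape;
  std::vector<double> data;
  std::size_t rank() const { return shape.size(); }
};

// Load one element given one index per dimension (row-major linearization).
// With an empty index list the offset stays 0, which is the scalar load.
double load(const ShapedValue &v, const std::vector<std::size_t> &indices) {
  assert(indices.size() == v.rank() && "expected one index per dimension");
  std::size_t offset = 0;
  for (std::size_t d = 0; d < v.rank(); ++d) {
    assert(indices[d] < v.shape[d] && "index out of bounds");
    offset = offset * v.shape[d] + indices[d];
  }
  return v.data[offset];
}

// Mirrors the branch in the review suggestion: a zero-rank operand is loaded
// without indices, while any other operand first derives its indices from the
// loop induction variables (here, assumed to be an identity map).
double loadOperand(const ShapedValue &v, const std::vector<std::size_t> &ivs) {
  if (!v.rank())
    return load(v, {});
  assert(ivs.size() >= v.rank() && "not enough induction variables");
  std::vector<std::size_t> indices(
      ivs.begin(), ivs.begin() + static_cast<std::ptrdiff_t>(v.rank()));
  return load(v, indices);
}
```

In this model a zero-rank operand yields the same element regardless of the induction variables, which corresponds to the "scalar load" that the updated test checks for.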