This is an archive of the discontinued LLVM Phabricator instance.

[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse) .
Closed · Public

Authored by pifon2a on Mar 29 2022, 10:01 AM.

Details

Summary

The original implementation does not take reassociation indices into account,
which leads to bugs.


Event Timeline

pifon2a created this revision. Mar 29 2022, 10:01 AM
Herald added a project: Restricted Project. Mar 29 2022, 10:01 AM
pifon2a requested review of this revision. Mar 29 2022, 10:01 AM
pifon2a retitled this revision from "[mlir] Rewrite canonicalization of expand_shape(collapse_shape)." to "[mlir] Rewrite canonicalization of collapse_shape(expand_shape)." Mar 29 2022, 10:03 AM
pifon2a edited the summary of this revision.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 220

You need to bail out on non-identity layouts for MemRefType for now.
You can use a simple templated hasNonIdentityLayout helper specialization to bail out early.
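The suggested helper could look roughly like the sketch below. It is standalone C++, so `TensorTypeStub` and `MemRefTypeStub` are hypothetical stand-ins for the MLIR tensor and memref types, `identityLayout` is an assumed field, and `canFoldReshapes` is a hypothetical caller; in MLIR itself the memref case would presumably query the memref's layout attribute instead. The primary template answers "no layout" for tensors, and the explicit specialization makes the memref case bail out whenever the layout is not the identity.

```cpp
#include <cassert>

// Hypothetical stand-ins for the MLIR tensor and memref types.
struct TensorTypeStub {};
struct MemRefTypeStub {
  bool identityLayout;  // assumed flag: true for the default row-major layout
};

// Primary template: tensors have no layout, so there is never a reason to bail.
template <typename ShapedTypeTy>
bool hasNonIdentityLayout(const ShapedTypeTy &) {
  return false;
}

// Explicit specialization for memrefs: bail on any non-identity layout.
template <>
bool hasNonIdentityLayout<MemRefTypeStub>(const MemRefTypeStub &type) {
  return !type.identityLayout;
}

// Hypothetical early exit at the top of a templated rewrite pattern.
template <typename ShapedTypeTy>
bool canFoldReshapes(const ShapedTypeTy &srcType) {
  if (hasNonIdentityLayout(srcType))
    return false;  // leave strided memrefs alone for now
  return true;
}
```

Because the dispatch happens at compile time, the same templated pattern body can serve both tensors and memrefs without a runtime type switch.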

Line 275

You need to bail out on non-identity layouts for MemRefType for now.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Line 1711

I don't think this can be correct in the absence of explicit layout map handling.

Line 1752

I don't think this can be correct in the absence of explicit layout map handling.

nicolasvasilache requested changes to this revision. Mar 30 2022, 12:44 AM
nicolasvasilache added inline comments.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 242

The double-loop logic plus the case disjunction looks unnecessarily complex to me.
You could just use two maps.

DenseMap<unsigned, unsigned> map1;
unsigned index = 0, group = 0;
for (auto &reassoc : op1.getReassociationIndices()) {
  for (unsigned i = 0, e = reassoc.size(); i < e; ++i)
    map1.insert(std::pair<unsigned, unsigned>(index++, group));
  ++group;
}
// Same for map2, this could become a separate reshape util by itself.

Then you can simply:

SmallVector<ReassociationIndices> composedReassociation(targetRank);
for (unsigned i = 0, e = sourceRank; i < e; ++i)
  composedReassociation[map2[map1[i]]].push_back(i);

This should work for all 4 cases (Collapse+Collapse, Collapse+Expand, Expand+Collapse, Expand+Expand).
Then you can just:

if (rankExpanding)
  rewriter.replaceOpWithNewOp<ExpandOpTy>(
      secondOp, resultType, firstOp.src(), composedReassociation);
else
  rewriter.replaceOpWithNewOp<CollapseOpTy>(
      secondOp, resultType, firstOp.src(), composedReassociation);
mlir/test/Dialect/MemRef/canonicalize.mlir
Line 371

Please add a test case with memref + layout and make sure it does not canonicalize.
Once they land, we can reuse layout computation utils for expand/collapse and expand (haha) the behavior.

This revision now requires changes to proceed. Mar 30 2022, 12:44 AM
pifon2a updated this revision to Diff 419144. Mar 30 2022, 8:09 AM
pifon2a marked 5 inline comments as done.

Address the comments.

Rewrite of the expand(collapse

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 242

I am not sure how this approach would work in cases where the composition is impossible.

Something like 2x5xf32 -> [0, 1] collapse to 10xf32 -> [0, 1] expand to 5x2xf32.

Or ?x?xf32 -> [0], [1, 2] expand to ?x?x2xf32 -> [0, 1], [2] collapse to ?x2xf32.

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 242

You could factor out the relevant parts of the verifier so as to ensure the resulting Expand/Collapse is valid, and bail out on invalid rewrites.
We already do a similar type of probing in other places (at least around cast ops, e.g. areCastCompatible, and some others I can't remember offhand).

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 242

Please disregard my suggestion; the intuition was wrong and too simplistic. My apologies for the noise.

Still, if there are ways to reuse parts of the verifier (and rewrite them where necessary if AffineMap is too annoying), please consider that rather than introducing new non-trivial code.

Thanks!

pifon2a updated this revision to Diff 419757. Apr 1 2022, 7:53 AM

Update the pattern for expand(collapse).

pifon2a retitled this revision from "[mlir] Rewrite canonicalization of collapse_shape(expand_shape)." to "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse) ." Apr 1 2022, 7:54 AM
pifon2a edited the summary of this revision.
pifon2a added inline comments.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Line 242

I reused parts of getReassociationIndicesForReshape.
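For intuition, here is a heavily simplified, static-shape-only sketch of the kind of greedy grouping a helper like getReassociationIndicesForReshape performs. It is not the MLIR implementation (which, among other things, also handles dynamic sizes); `inferReassociation` is a hypothetical name, and returning `std::nullopt` plays the role of "this pair of reshapes cannot be folded into one".

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using ReassociationIndices = std::vector<unsigned>;
using Reassociation = std::vector<ReassociationIndices>;

// Greedily groups consecutive dims of the higher-rank static shape until their
// product matches the next dim of the lower-rank shape. Returns std::nullopt
// if no such grouping exists.
std::optional<Reassociation> inferReassociation(
    const std::vector<int64_t> &expanded,
    const std::vector<int64_t> &collapsed) {
  if (collapsed.size() > expanded.size())
    return std::nullopt;
  Reassociation result;
  size_t dim = 0;
  for (int64_t target : collapsed) {
    ReassociationIndices group;
    int64_t product = 1;
    // Take at least one dim, then keep going while the product is too small.
    while (dim < expanded.size() && (group.empty() || product < target)) {
      product *= expanded[dim];
      group.push_back(static_cast<unsigned>(dim++));
    }
    if (product != target)
      return std::nullopt;
    result.push_back(group);
  }
  // Leftover dims must all be 1: they are folded into the last group, or, for
  // a rank-0 collapsed shape, the reassociation simply stays empty.
  for (; dim < expanded.size(); ++dim) {
    if (expanded[dim] != 1)
      return std::nullopt;
    if (!result.empty())
      result.back().push_back(static_cast<unsigned>(dim));
  }
  return result;
}
```

On the earlier counterexample, `inferReassociation({2, 5}, {5, 2})` finds no grouping, so `2x5xf32 -> 10xf32 -> 5x2xf32` would be left alone; the dynamic `?x?xf32` example is outside the scope of this sketch.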

pifon2a updated this revision to Diff 419767. Apr 1 2022, 8:16 AM
pifon2a edited the summary of this revision.

Update the strided example.

pifon2a updated this revision to Diff 419768. Apr 1 2022, 8:24 AM

Fix compiler warnings.

bkramer accepted this revision. Apr 4 2022, 11:10 AM

Looks good, let's get this rolling to unblock stuff.

This revision was not accepted when it landed; it landed in state Needs Review. Apr 5 2022, 1:10 AM
This revision was automatically updated to reflect the committed changes.

@nicolasvasilache I pushed the commit to fix the bugs. If you have further suggestions or comments, I'll be happy to address them.

I think there are bugs in the commit: the verifier fails after canonicalization. To reproduce:

$ mlir-opt -canonicalize a.mlir

func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
  %6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
  %10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %12 = linalg.init_tensor [8] : tensor<8xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  %16 = linalg.init_tensor [] : tensor<f32>
  %17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<f32>
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
  %19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %20 = linalg.init_tensor [8] : tensor<8xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
  %24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
  %25 = linalg.init_tensor [8] : tensor<8xf32>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
  %28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
  %29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
  return %29 : tensor<1x1xf32>
}

results in

a.mlir:24:9: error: 'tensor.expand_shape' op expected rank of the collapsed type(2) to be the number of reassociation maps(0)
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
        ^
a.mlir:24:9: note: see current operation: %10 = "tensor.expand_shape"(%arg1) {reassociation = []} : (tensor<1x1xf32>) -> tensor<1x1x1x1xf32>
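The message points at a structural invariant of reshape ops: the reassociation must have one group per dimension of the collapsed type. The rewritten op still carries an empty reassociation `[]`, which is only valid for a rank-0 collapsed type such as the original `tensor<f32>`, while its source is now the rank-2 `%arg1`. A minimal sketch of a guard that a pattern could run before creating the op might look like this; `isStructurallyValid` is a hypothetical name, and the check is modeled on the error message rather than copied from MLIR's verifier.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ReassociationIndices = std::vector<unsigned>;
using Reassociation = std::vector<ReassociationIndices>;

// Hypothetical structural check for an expand/collapse reassociation:
// there must be exactly one non-empty group per collapsed dimension, and the
// groups, read in order, must list 0, 1, ..., expandedRank - 1 exactly once.
// A rank-0 collapsed type is the one case where the list is empty.
bool isStructurallyValid(size_t collapsedRank, size_t expandedRank,
                         const Reassociation &reassoc) {
  if (reassoc.size() != collapsedRank)
    return false;  // e.g. collapsed rank 2 but 0 reassociation groups
  unsigned next = 0;
  for (const ReassociationIndices &group : reassoc) {
    if (group.empty())
      return false;
    for (unsigned dim : group)
      if (dim != next++)
        return false;  // groups must be contiguous and in order
  }
  return collapsedRank == 0 || next == expandedRank;
}
```

Applied to the failing op, `isStructurallyValid(2, 4, {})` is false, whereas a reassociation such as `[[0], [1, 2, 3]]` would satisfy the structural rule for `tensor<1x1xf32>` into `tensor<1x1x1x1xf32>`.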