The original implementation does not take reassociation indices into account,
which leads to bugs.
Details
Diff Detail
- Repository
- rG LLVM Github Monorepo
Event Timeline
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

- Lines 237–239: You need to bail on non-identity layout for MemRefType atm.
- Line 292: You need to bail on non-identity layout for MemRefType atm.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

- Lines 1962–1963: I don't think this can be correct in the absence of explicit layout map handling.
- Lines 2003–2004: I don't think this can be correct in the absence of explicit layout map handling.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h, line 259:

The complexity of the double-loop logic + case disjunction looks unnecessary to me.

```cpp
DenseMap<unsigned, unsigned> map1;
unsigned index = 0, group = 0;
for (auto &reassoc : op1.getReassociationIndices()) {
  for (unsigned i = 0, e = reassoc.size(); i < e; ++i)
    map1.insert(std::pair<unsigned, unsigned>(index++, group));
  ++group;
}
// Same for map2; this could become a separate reshape util by itself.
```

Then you can simply:

```cpp
SmallVector<ReassociationIndices> composedReassociation(targetRank);
for (unsigned i = 0, e = sourceRank; i < e; ++i)
  composedReassociation[map2[map1[i]]].push_back(i);
```

This should work for all 4 cases (Collapse+Collapse, Collapse+Expand, Expand+Collapse, Expand+Expand).

```cpp
if (rankExpanding)
  rewriter.replaceOpWithNewOp<ExpandOpTy>(
      secondOp, resultType, firstOp.src(), composedReassociation);
else
  rewriter.replaceOpWithNewOp<CollapseOpTy>(
      secondOp, resultType, firstOp.src(), composedReassociation);
```
mlir/test/Dialect/MemRef/canonicalize.mlir, lines 388–389:

Please add a test case with memref + layout and make sure it does not canonicalize.
Rewrite of the expand(collapse
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h, line 259:

I am not sure how this approach would work for the cases when the composition is impossible, something like 2x5xf32 -> [0, 1] collapse to 10xf32 -> [0, 1] expand to 5x2xf32. Or ?x?xf32 -> [0], [1, 2] expand to ?x?x2xf32 -> [0, 1], [2] collapse to ?x2xf32.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h, line 259:

You could factor out the relevant parts of the verifier to ensure the resulting Expand/Collapse is valid, and bail on invalid rewrites.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h, line 259:

Please disregard my suggestion; the intuition was wrong and too simplistic, my apologies for the noise. Still, if there are ways to reuse parts of the verifier (and rewrite where necessary if the AffineMap handling is too annoying), please consider that rather than introducing new non-trivial code. Thanks!
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h, line 259:

I reused parts of getReassociationIndicesForReshape.
@nicolasvasilache I pushed the commit to fix the bugs. If you have further suggestions/comments, I'll be happy to address them.
I think there are still bugs in the commit: the verifier fails on the following input. To repro:
```
$ mlir-opt -canonicalize a.mlir
```

where a.mlir is:

```mlir
func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
  %6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
  %10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %12 = linalg.init_tensor [8] : tensor<8xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  %16 = linalg.init_tensor [] : tensor<f32>
  %17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<f32>
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
  %19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %20 = linalg.init_tensor [8] : tensor<8xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
  %24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
  %25 = linalg.init_tensor [8] : tensor<8xf32>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
  %28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
  %29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
  return %29 : tensor<1x1xf32>
}
```
results in
```
a.mlir:24:9: error: 'tensor.expand_shape' op expected rank of the collapsed type(2) to be the number of reassociation maps(0)
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
        ^
a.mlir:24:9: note: see current operation: %10 = "tensor.expand_shape"(%arg1) {reassociation = []} : (tensor<1x1xf32>) -> tensor<1x1x1x1xf32>
```
You need to bail on non-identity layout for MemRefType atm.
You can use a simple templated hasNonIdentityLayout helper specialization to bail out early.