This reverts commit 64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7.
With the reverted commit applied, canonicalization generates an invalid tensor.expand_shape op. To reproduce, save the IR below as a.mlir and run:
$ mlir-opt -canonicalize a.mlir
// Reproducer: running `mlir-opt -canonicalize` on this function produced an
// invalid tensor.expand_shape op with the reverted commit applied.
// NOTE(review): this IR does not show which op the canonicalizer rewrites into
// the invalid expand_shape; the verifier error emitted by mlir-opt identifies it.
func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Place %0 at element [0, 0] of a zero-filled 8x1 tensor.
  %3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
  %6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  // Place %2 at element [0, 0] of a second zero-filled 8x1 tensor.
  %7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
  %10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  // mmt4d LHS: collapse %6 to 8xf32, copy it through an identity generic
  // (the body yields its input), then expand to 1x1x8x1.
  %11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %12 = linalg.init_tensor [8] : tensor<8xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  // mmt4d RHS: collapse %1 to rank 0, copy it through a rank-0 generic,
  // then expand to 1x1x1x1.
  %15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  %16 = linalg.init_tensor [] : tensor<f32>
  %17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<f32>
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
  // mmt4d accumulator: the same collapse / copy / expand sequence applied to %10.
  %19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %20 = linalg.init_tensor [8] : tensor<8xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
  // Collapse the mmt4d result to 8xf32, copy it, and expand it back to 8x1.
  %24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
  %25 = linalg.init_tensor [8] : tensor<8xf32>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
  // Rank-reducing extract of element [0, 0], then expand the rank-0 tensor
  // back to the 1x1 return type.
  %28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
  %29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
  return %29 : tensor<1x1xf32>
}