diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -456,7 +456,48 @@
     if (!srcPadOp)
       return failure();
 
-    auto dstFillOp = insertOp.dest().getDefiningOp<linalg::FillOp>();
+    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
+      return failure();
+
+    // Walk back the tensor.insert_slice chain and find the first destination
+    // value at the start of the chain.
+    Value firstDest = insertOp.dest();
+    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
+      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
+        return failure();
+
+      // Make sure the ranges of values accessed are disjoint. Without this, we
+      // cannot fold tensor.pad away.
+      bool disjoint = false;
+      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
+        // If the dimension has a dynamic offset/size/stride, we cannot
+        // guarantee disjointness. So just skip it.
+        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
+            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
+            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
+          continue;
+
+        // Get the range start and end, inclusively for both.
+        int64_t prevStart = prevOp.getStaticOffset(i);
+        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
+                                          prevOp.getStaticStride(i);
+        int64_t nextStart = insertOp.getStaticOffset(i);
+        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
+                                          insertOp.getStaticStride(i);
+        if (prevEnd < nextStart || nextEnd < prevStart) {
+          disjoint = true;
+          break;
+        }
+      }
+
+      if (!disjoint)
+        break;
+      firstDest = prevOp.dest();
+    }
+
+    // Check whether the first destination is a fill op. For overlapped cases,
+    // this also cannot be true.
+    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
     if (!dstFillOp)
       return failure();
 
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -642,3 +642,108 @@
   %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   return %0: tensor<8x384x384xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<7x123x124xf32>, %[[A:.+]]: tensor<8x128x128xf32>, %[[OFFSET:.+]]: index)
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[A]] into %[[FILL]][%[[OFFSET]], 0, 0] [8, 128, 128] [1, 1, 1]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[A]] into %[[INSERT0]][0, 128, %[[OFFSET]]] [8, 128, 128] [1, 1, 1]
+// CHECK: tensor.insert_slice %[[INPUT]] into %[[INSERT1]][1, 2, 256] [7, 123, 124] [1, 1, 1]
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a into %0 [0, 0, 129] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  // Range overlap with %1 at dim#3
+  %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a into %0 [0, 128, 255] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  // Range overlap with %0 at dim#3
+  %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  // Overlap between %0 and %1 is fine as long as neither overlaps with %2.
+  // CHECK-COUNT-3: tensor.insert_slice
+  %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a into %0 [0, 1, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1 [0, 256, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_mismatch
+func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  // Different filling value than padding value.
+  %fill = linalg.fill(%f1, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}