diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -116,19 +116,26 @@ currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber()); } - // Fail if `currOpOperand` is not defined by an ExtractSliceOp. - auto sliceOp = currOpOperand->get().getDefiningOp(); - if (!sliceOp) + // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp. + llvm::SmallBitVector droppedDims; + SmallVector mixedSizes; + if (auto sliceOp = + currOpOperand->get().getDefiningOp()) { + // Compute the dropped dimensions if `sliceOp` is rank-reducing. + droppedDims = sliceOp.getDroppedDims(); + mixedSizes = sliceOp.getMixedSizes(); + } else if (auto emptyOp = + currOpOperand->get().getDefiningOp()) { + mixedSizes = emptyOp.getMixedSizes(); + droppedDims.resize(mixedSizes.size()); + } else { return failure(); + } - // Compute the dropped dimensions if `sliceOp` is ranke-reducing. - llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims(); - OffsetSizeAndStrideOpInterface shapedOp = sliceOp; - - // Upper bound the `sliceOp` sizes to obtain a static bounding box. + // Upper bound the sizes to obtain a static bounding box. SmallVector paddedShape(shape.begin(), shape.end()); int64_t shapeIdx = 0; - for (const auto &en : enumerate(shapedOp.getMixedSizes())) { + for (const auto &en : enumerate(mixedSizes)) { // Skip dropped dimensions. 
if (droppedDims.test(en.index())) continue; diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -39,6 +39,45 @@ // ----- +#map = affine_map<()[s0] -> (-s0 + 12, 7)> + +// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op +func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> { + %0 = affine.min #map()[%iv2] + + // CHECK: %[[T0:.*]] = tensor.empty + // CHECK: %[[T1:.*]] = tensor.empty + // CHECK: %[[T2:.*]] = tensor.empty + %1 = tensor.empty(%0) : tensor<4x?xf32> + %2 = tensor.empty(%0) : tensor + %3 = tensor.empty() : tensor<4x5xf32> + + // CHECK-DAG: %[[CST:.*]] = arith.constant 0. + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + + // CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold + // CHECK: tensor.yield %[[CST]] + // CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold + + // CHECK: %[[T5:.*]] = linalg.matmul + // CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>) + // CHECK-SAME: outs(%[[T2]] : tensor<4x5xf32>) + %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> + %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> + func.return %5 : tensor<24x25xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} +} + +// ----- + func.func @pad(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {