diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1134,11 +1134,84 @@
   }
 };
 
+/// Rewrite a PadTensorOp on a zero-sized tensor into a GenerateOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src low[2, 3] high[4, 5]
+///     : tensor<0x0xf32> to tensor<6x8xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %0 = tensor.generate : tensor<6x8xf32>
+/// ```
+/// Note: The regions of both ops are omitted in the example above. Static,
+/// dynamic and mixed static/dynamic low/high padding values are supported.
+struct PadZeroSizeTensor : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = padOp.getSourceType();
+    auto resType = padOp.getResultType();
+    auto loc = padOp.getLoc();
+
+    // Pattern applies if at least one dimension has static size 0.
+    if (!llvm::any_of(srcType.getShape(),
+                      [](int64_t dim) { return dim == 0; }))
+      return failure();
+
+    // Helper function for converting an OpFoldResult to a Value.
+    auto toValue = [&](OpFoldResult ofr) -> Value {
+      if (auto val = ofr.dyn_cast<Value>())
+        return val;
+      auto intVal = getConstantIntValue(ofr);
+      assert(intVal && "expected Value or IntegerAttr");
+      return rewriter.create<ConstantIndexOp>(loc, *intVal);
+    };
+
+    // Compute the dynamic dimensions of the new GenerateOp: for every
+    // dynamic result dimension, size = dim(source) + lowPad + highPad.
+    SmallVector<Value> dynDims;
+    for (int64_t i = 0; i < resType.getRank(); ++i) {
+      if (resType.isDynamicDim(i)) {
+        auto dimSize =
+            rewriter.createOrFold<memref::DimOp>(loc, padOp.source(), i);
+        auto s1 = rewriter.create<AddIOp>(
+            loc, dimSize, toValue(padOp.getMixedLowPad()[i]));
+        auto s2 = rewriter.create<AddIOp>(
+            loc, s1, toValue(padOp.getMixedHighPad()[i]));
+        dynDims.push_back(s2);
+      }
+    }
+
+    // Emit GenerateOp.
+    auto generateOp = rewriter.create<tensor::GenerateOp>(
+        loc, padOp.getResultType(), dynDims);
+    // Copy region to new op.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+    // Rewrite linalg::YieldOp to tensor::YieldOp.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto yieldOp = dyn_cast<linalg::YieldOp>(
+          generateOp.getRegion().front().getTerminator());
+      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+      assert(yieldOp.values().size() == 1);
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp,
+                                                   yieldOp.values()[0]);
+    }
+    // Replace PadTensorOp with new GenerateOp.
+    rewriter.replaceOp(padOp, generateOp.getResult());
+
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldSourceTensorCast, FoldTargetTensorCast>(context);
+  results.add<FoldSourceTensorCast, FoldTargetTensorCast, PadZeroSizeTensor>(
+      context);
 }
 
 /// Return the padding value of the PadTensorOp if it is constant.
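For context, a rough before/after sketch of what the pattern produces for a source with one dynamic low pad, once the canonicalizer has also folded the `memref.dim` of the static zero dimension and the constant additions. This IR is hand-written for illustration (not taken from the patch or its tests); `%src`, `%l`, and `%pad` are placeholder names:

```mlir
// Before: pad a zero-sized source; dim 0 of the result is dynamic
// because the low pad %l is dynamic.
%0 = linalg.pad_tensor %src low[%l, 3] high[4, 5] {
^bb0(%i: index, %j: index):
  linalg.yield %pad : f32
} : tensor<0x0xf32> to tensor<?x8xf32>

// After: the pad is gone; dim 0 = 0 + %l + 4, dim 1 = 0 + 3 + 5 = 8.
%c4 = constant 4 : index
%s0 = addi %l, %c4 : index
%0 = tensor.generate %s0 {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<?x8xf32>
```

Reading the source tensor's dimension sizes is always safe here: where the source dimension is static (in particular the mandatory 0), `createOrFold` folds the `memref.dim` to a constant, so only genuinely dynamic sizes survive as SSA values.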
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1148,3 +1148,27 @@
 // CHECK-LABEL: @tensor_pad_cast
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK: return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @pad_zero_tensor_to_generate(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x0x2xf32>, %[[L1:.*]]: index, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[S1:.*]] = addi %[[DIM]], %[[C5]]
+// CHECK: %[[S2:.*]] = addi %[[L1]], %[[C2]]
+// CHECK: %[[RESULT:.*]] = tensor.generate %[[S1]], %[[S2]] {
+// CHECK: tensor.yield %[[PAD]] : f32
+// CHECK: } : tensor<?x?x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_zero_tensor_to_generate(
+    %arg0: tensor<?x0x2xf32>, %l1: index, %pad: f32) -> tensor<?x?x6xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, %l1, 3] high[4, 2, 1] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %pad : f32
+  } : tensor<?x0x2xf32> to tensor<?x?x6xf32>
+  return %0 : tensor<?x?x6xf32>
+}
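The test exercises the mixed static/dynamic case. For completeness, here is the fully static case from the pattern's doc comment with the regions spelled out, sketched by hand rather than taken from the test file: the GenerateOp then needs no dynamic extents at all, and the dimension arithmetic folds away entirely:

```mlir
%0 = linalg.pad_tensor %src low[2, 3] high[4, 5] {
^bb0(%i: index, %j: index):
  linalg.yield %pad : f32
} : tensor<0x0xf32> to tensor<6x8xf32>

// canonicalizes to:
%0 = tensor.generate {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<6x8xf32>
```

Note that the pad body is cloned into the GenerateOp unchanged (only `linalg.yield` is rewritten to `tensor.yield`), so a pad body that computes a position-dependent value from the index block arguments is preserved as well.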