diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -689,10 +689,6 @@ padOp.getLoc(), getIntFromAttr(ofr.get())).getResult(); }; - // Pad value must be a constant. - auto padValue = padOp.getConstantPaddingValue(); - if (!padValue) return failure(); - auto resultType = padOp.getResultType(); // Compute size of InitTensorOp. Any combination of static/dynamic is // supported. @@ -712,20 +708,20 @@ staticSizes.push_back(resultType.getDimSize(dim)); } + // Init tensor and fill it with padding. Value init = rewriter.create( padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType()); - Value fill = - rewriter.create(padOp.getLoc(), init, padValue).result(); - - auto sourceType = padOp.getSourceType(); + Value fill = tryVectorizeFill(rewriter, padOp, init, dynSizes); // Try vectorizing the copy of source. - if (tryVectorizeCopy(rewriter, padOp, padValue, fill).succeeded()) + if (tryVectorizeCopy(rewriter, padOp, fill).succeeded()) return success(); // Neither source type nor PadTensorOp result type have static shape. Such - // PadTensorOps cannot be vectorized. Generate a SubTensorInsertOp instead. + // PadTensorOps cannot be vectorized. Generate a SubTensorInsertOp instead + // for copying the PadOp source. + auto sourceType = padOp.getSourceType(); // Compute size of source of PadTensorOp. SmallVector srcSizes; for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { @@ -745,14 +741,54 @@ return success(); } + /// Fill `dest` with padding. If the padOp pads with a constant value, emit a + /// FillOp on `dest`. Otherwise, emit a tensor::GenerateOp (`dest` is unused). + Value tryVectorizeFill(PatternRewriter &rewriter, PadTensorOp padOp, + Value dest, const SmallVector &dynSizes) const { + // Fill can be vectorized if padValue is a constant. 
(If there is enough + // static type information, the FillOp will be vectorized by another + // pattern.) + auto padValue = padOp.getConstantPaddingValue(); + if (padValue) + return rewriter.create(padOp.getLoc(), dest, padValue).result(); + + // Fill could not be vectorized: Lower to tensor::GenerateOp with region. + auto generateOp = rewriter.create( + padOp.getLoc(), padOp.getResultType(), dynSizes); + // Clone the padOp region into the new GenerateOp. + BlockAndValueMapping bvm; + padOp.region().cloneInto(&generateOp.getRegion(), bvm); + // Rewrite linalg::YieldOp to tensor::YieldOp. + OpBuilder::InsertionGuard guard(rewriter); + auto yieldOp = dyn_cast( + generateOp.getRegion().front().getTerminator()); + assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator"); + assert(yieldOp.values().size() == 1); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.values()[0]); + return generateOp; + } + /// Vectorize the copying of a PadTensorOp's source. This is possible if each /// dimension size is statically know in the source type or the result type /// (or both). LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, PadTensorOp padOp, - Value padValue, Value dest) const { + Value dest) const { auto sourceType = padOp.getSourceType(); auto resultType = padOp.getResultType(); + // Copy cannot be vectorized if pad value is non-constant and source shape + // is dynamic. In case of a dynamic source shape, padding must be appended + // by TransferReadOp, but TransferReadOp supports only constant padding. + auto padValue = padOp.getConstantPaddingValue(); + if (!padValue) { + if (!sourceType.hasStaticShape()) return failure(); + // Create dummy zero padding value (source shape is static at this point). 
+ auto elemType = sourceType.getElementType(); + padValue = rewriter.create(padOp.getLoc(), elemType, + rewriter.getZeroAttr(elemType)); + } + SmallVector vecShape; SmallVector readInBounds; SmallVector writeInBounds; diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -674,6 +674,35 @@ // ----- +// CHECK-LABEL: func @pad_tensor_non_const_pad_value +// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32> +// CHECK-NOT: linalg.pad_tensor +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C3:.*]] = constant 3 : index +// CHECK-DAG: %[[C4:.*]] = constant 4 : index +// CHECK: %[[FILL:.*]] = tensor.generate +// CHECK: %[[RES:.*]] = mulf +// CHECK: tensor.yield %[[RES]] : f32 +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : tensor<5x6xf32>, vector<5x6xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C3]], %[[C4]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<12x13xf32> +// CHECK: return %[[WRITE]] +func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> { + %c0 = constant 0 : index + %c5 = constant 5.0 : f32 + %0 = linalg.pad_tensor %arg0 low[3, 4] high[4, 3] { + ^bb0(%arg1: index, %arg2: index): + %i1 = index_cast %arg1 : index to i32 + %i2 = index_cast %arg2 : index to i32 + %f1 = sitofp %i1 : i32 to f32 + %f2 = sitofp %i2 : i32 to f32 + %m = mulf %f1, %f2 : f32 + linalg.yield %m : f32 + } : tensor<5x6xf32> to tensor<12x13xf32> + return %0 : tensor<12x13xf32> +} + +// ----- + // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)> // CHECK-LABEL: func @sum_exp