diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1275,6 +1275,11 @@
     %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
       {permutation_map = (d0, d1)->(d0, d1)}
        : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+    // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+    // {1} and permutation_map () -> (0).
+    %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+      tensor<f32>, vector<1xf32>
     ```
   }];
@@ -1402,6 +1407,11 @@
     %5 = vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
        : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+
+    // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape
+    // {1} and permutation_map () -> (0).
+    %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, tensor<f32>
     ```
   }];
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,29 @@
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if op involves a 0-d tensor/memref and a vector
+        of shape {1}. This is temporary until we have 0-d vectors.
+        // TODO: turn this into 0-d vectors + empty permutation_map.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isZeroD",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (getShapedType().getRank() > 0)
+          return false;
+        if (getVectorType().getShape() != ArrayRef<int64_t>{1})
+          return false;
+        AffineMap map = AffineMap::get(
+            /*numDims=*/0, /*numSymbols=*/0,
+            getAffineConstantExpr(0, $_op->getContext()));
+        if ($_op.permutation_map() != map)
+          return false;
+        return true;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
       /*retTy=*/"bool",
@@ -134,6 +157,9 @@
       /*args=*/(ins),
       /*methodBody=*/"",
      /*defaultImplementation=*/[{
+        // 0-d transfers are not considered broadcasts but they need to be
+        // represented with a vector<1xt> until we have 0-d vectors.
+        if ($_op.isZeroD()) return false;
        for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
          if ($_op.isBroadcastDim(i))
            return true;
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2292,11 +2292,14 @@
   return success();
 }
 
-static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
-                                      VectorType vectorType,
-                                      VectorType maskType,
-                                      AffineMap permutationMap,
-                                      ArrayAttr inBounds) {
+static LogicalResult
+verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
+                 VectorType vectorType, VectorType maskType,
+                 AffineMap permutationMap, ArrayAttr inBounds) {
+  if (shapedType.getRank() == 0 && !op.isZeroD())
+    return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
+                           "(0) permutation_map");
+
   if (op->hasAttr("masked")) {
     return op->emitOpError("masked attribute has been removed. "
                            "Use in_bounds instead.");
@@ -2358,7 +2361,8 @@
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
 
-  if (permutationMap.getNumInputs() != shapedType.getRank())
+  // TODO: implement 0-d vector corner cases.
+  if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the source type");
 
@@ -2534,9 +2538,10 @@
   if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
-  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              maskType, permutationMap,
-                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+  if (failed(
+          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+                           shapedType, vectorType, maskType, permutationMap,
+                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
   if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
@@ -2609,6 +2614,9 @@
 
 template <typename TransferOp>
 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
+  // TODO: Be less conservative once we have 0-d vectors.
+  if (op.isZeroD())
+    return failure();
   AffineMap permutationMap = op.permutation_map();
   bool changed = false;
   SmallVector<bool, 4> newInBounds;
@@ -2885,9 +2893,10 @@
   if (op.hasBroadcastDim())
     return op.emitOpError("should not have broadcast dimensions");
 
-  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              maskType, permutationMap,
-                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+  if (failed(
+          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
+                           shapedType, vectorType, maskType, permutationMap,
+                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
   return verifyPermutationMap(permutationMap,
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -239,6 +239,13 @@
       shapedType.getElementType().dyn_cast<VectorType>();
   if (elementVectorType)
     elementVectorRank += elementVectorType.getRank();
+  // 0-d transfers are to/from tensor/memref and vector<1xt>.
+  // TODO: replace once we have 0-d vectors.
+  if (shapedType.getRank() == 0 &&
+      vectorType.getShape() == ArrayRef<int64_t>{1})
+    return AffineMap::get(
+        /*numDims=*/0, /*numSymbols=*/0,
+        getAffineConstantExpr(0, shapedType.getContext()));
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
       shapedType.getContext());
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1394,3 +1394,16 @@
   // expected-error@+1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}}
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
+
+// -----
+
+func @vector_transfer_ops_0d(%arg0: tensor<f32>)
+  -> tensor<f32> {
+  %f0 = constant 0.0 : f32
+  // expected-error@+1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}}
+  %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} :
+    tensor<f32>, vector<1xf32>
+  %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+    vector<1xf32>, tensor<f32>
+  return %1: tensor<f32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1,5 +1,20 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+func @vector_transfer_ops_0d(%arg0: tensor<f32>, %arg1: memref<f32>)
+  -> tensor<f32> {
+  %f0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
+    tensor<f32>, vector<1xf32>
+  %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
+    vector<1xf32>, tensor<f32>
+  %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} :
+    memref<f32>, vector<1xf32>
+  vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} :
+    vector<1xf32>, memref<f32>
+  return %1: tensor<f32>
+}
+
 // CHECK-LABEL: func @vector_transfer_ops(
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
                           %arg1 : memref<?x?xvector<4x3xf32>>,