diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1203,9 +1203,11 @@ provided to specify a fallback value in the case of out-of-bounds accesses and/or masking. - An optional SSA value `mask` of the same shape as the vector type may be - specified to mask out elements. Such elements will be replaces with - `padding`. Elements whose corresponding mask element is `0` are masked out. + An optional SSA value `mask` may be specified to mask out elements read from + the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that + matches how elements are read from the MemRef/Tensor, *before* any + permutation or broadcasting. Elements whose corresponding mask element is + `0` are masked out and replaced with `padding`. An optional boolean array attribute `in_bounds` specifies for every vector dimension if the transfer is guaranteed to be within the source bounds. @@ -1415,6 +1417,12 @@ The size of the slice is specified by the size of the vector. + An optional SSA value `mask` may be specified to mask out elements written + to the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that + matches how elements are written into the MemRef/Tensor, *after* applying + any permutation. Elements whose corresponding mask element is `0` are + masked out. + An optional SSA value `mask` of the same shape as the vector type may be specified to mask out elements. Elements whose corresponding mask element is `0` are masked out. 
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h --- a/mlir/include/mlir/Interfaces/VectorInterfaces.h +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h @@ -17,18 +17,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -namespace mlir { -namespace vector { -namespace detail { - -/// Given the vector type and the permutation map of a vector transfer op, -/// compute the expected mask type. -VectorType transferMaskType(VectorType vecType, AffineMap map); - -} // namespace detail -} // namespace vector -} // namespace mlir - /// Include the generated interface declarations. #include "mlir/Interfaces/VectorInterfaces.h.inc" diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -169,16 +169,25 @@ }] >, InterfaceMethod< - /*desc=*/"Return the mask type if the op has a mask.", + /*desc=*/"Return the mask operand if the op has a mask. Otherwise, " + "return an empty value.", + /*retTy=*/"Value", + /*methodName=*/"getMask", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getMask(); + }] + >, + InterfaceMethod< + /*desc=*/"Return the mask type if the op has a mask. Otherwise, return " + "an empty VectorType.", /*retTy=*/"::mlir::VectorType", /*methodName=*/"getMaskType", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getMask() - ? ::mlir::vector::detail::transferMaskType( - $_op.getVectorType(), $_op.getPermutationMap()) - : ::mlir::VectorType(); + return $_op.getMask() ? 
$_op.getMask().getType() : ::mlir::VectorType(); }] >, InterfaceMethod< diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2873,7 +2873,8 @@ static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, - AffineMap permutationMap, ArrayAttr inBounds) { + VectorType inferredMaskType, AffineMap permutationMap, + ArrayAttr inBounds) { if (op->hasAttr("masked")) { return op->emitOpError("masked attribute has been removed. " "Use in_bounds instead."); @@ -2926,13 +2927,6 @@ if (permutationMap.getNumResults() != vectorType.getRank()) return op->emitOpError("requires a permutation_map with result dims of " "the same rank as the vector type"); - - VectorType expectedMaskType = - vector::detail::transferMaskType(vectorType, permutationMap); - if (maskType && expectedMaskType != maskType) - return op->emitOpError("expects mask type consistent with permutation " - "map: ") - << maskType; } if (permutationMap.getNumSymbols() != 0) @@ -2942,6 +2936,11 @@ return op->emitOpError("requires a permutation_map with input dims of the " "same rank as the source type"); + if (maskType && maskType != inferredMaskType) + return op->emitOpError("inferred mask type (") + << inferredMaskType << ") and mask operand type (" << maskType + << ") don't match"; + if (inBounds) { if (permutationMap.getNumResults() != static_cast(inBounds.size())) return op->emitOpError("expects the optional in_bounds attr of same rank " @@ -2984,6 +2983,19 @@ p << " : " << getShapedType() << ", " << getVectorType(); } +/// Infers the mask type for a transfer read given its vector type and +/// permutation map. The mask in a transfer read operation applies to the +/// tensor/buffer reading part of it and its type should match the shape read +/// *before* any permutation or broadcasting. 
+static VectorType inferTransferReadMaskType(VectorType vecType, + AffineMap permMap) { + auto i1Type = IntegerType::get(permMap.getContext(), 1); + AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap)); + assert(invPermMap && "Inversed permutation map couldn't be computed"); + SmallVector maskShape = invPermMap.compose(vecType.getShape()); + return VectorType::get(maskShape, i1Type); +} + ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); SMLoc typesLoc; @@ -3014,13 +3026,14 @@ VectorType vectorType = types[1].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName(); - Attribute mapAttr = result.attributes.get(permutationAttrName); - if (!mapAttr) { - auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); - // Update `mapAttr` that is used later to determine mask type. - mapAttr = AffineMapAttr::get(permMap); - result.attributes.set(permutationAttrName, mapAttr); + auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName(); + Attribute permMapAttr = result.attributes.get(permMapAttrName); + AffineMap permMap; + if (!permMapAttr) { + permMap = getTransferMinorIdentityMap(shapedType, vectorType); + result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); + } else { + permMap = permMapAttr.cast().getValue(); } if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || @@ -3031,10 +3044,9 @@ if (shapedType.getElementType().dyn_cast()) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); - auto map = mapAttr.dyn_cast().getValue(); // Instead of adding the mask type as an op type, compute it based on the // vector type and the permutation map (to keep the type signature small). 
- auto maskType = mlir::vector::detail::transferMaskType(vectorType, map); + auto maskType = inferTransferReadMaskType(vectorType, permMap); if (parser.resolveOperand(maskInfo, maskType, result.operands)) return failure(); } @@ -3052,13 +3064,17 @@ VectorType maskType = getMaskType(); auto paddingType = getPadding().getType(); auto permutationMap = getPermutationMap(); + VectorType inferredMaskType = + maskType ? inferTransferReadMaskType(vectorType, permutationMap) + : VectorType(); auto sourceElementType = shapedType.getElementType(); if (static_cast(getIndices().size()) != shapedType.getRank()) return emitOpError("requires ") << shapedType.getRank() << " indices"; if (failed(verifyTransferOp(cast(getOperation()), - shapedType, vectorType, maskType, permutationMap, + shapedType, vectorType, maskType, + inferredMaskType, permutationMap, getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); @@ -3422,6 +3438,18 @@ build(builder, result, vector, dest, indices, permutationMap, inBounds); } +/// Infers the mask type for a transfer write given its vector type and +/// permutation map. The mask in a transfer write operation applies to the +/// tensor/buffer writing part of it and its type should match the shape written +/// *after* any permutation. 
+static VectorType inferTransferWriteMaskType(VectorType vecType, + AffineMap permMap) { + auto i1Type = IntegerType::get(permMap.getContext(), 1); + SmallVector maskShape = + compressUnusedDims(permMap).compose(vecType.getShape()); + return VectorType::get(maskShape, i1Type); +} + ParseResult TransferWriteOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); @@ -3449,11 +3477,14 @@ ShapedType shapedType = types[1].dyn_cast(); if (!shapedType || !shapedType.isa()) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName(); - auto attr = result.attributes.get(permutationAttrName); - if (!attr) { - auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); - result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); + auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName(); + auto permMapAttr = result.attributes.get(permMapAttrName); + AffineMap permMap; + if (!permMapAttr) { + permMap = getTransferMinorIdentityMap(shapedType, vectorType); + result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); + } else { + permMap = permMapAttr.cast().getValue(); } if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || parser.resolveOperand(sourceInfo, shapedType, result.operands) || @@ -3463,7 +3494,7 @@ if (shapedType.getElementType().dyn_cast()) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); - auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); + auto maskType = inferTransferWriteMaskType(vectorType, permMap); if (parser.resolveOperand(maskInfo, maskType, result.operands)) return failure(); } @@ -3489,6 +3520,9 @@ VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); auto permutationMap = getPermutationMap(); + VectorType inferredMaskType = + maskType ? 
inferTransferWriteMaskType(vectorType, permutationMap) + : VectorType(); if (llvm::size(getIndices()) != shapedType.getRank()) return emitOpError("requires ") << shapedType.getRank() << " indices"; @@ -3499,7 +3533,8 @@ return emitOpError("should not have broadcast dimensions"); if (failed(verifyTransferOp(cast(getOperation()), - shapedType, vectorType, maskType, permutationMap, + shapedType, vectorType, maskType, + inferredMaskType, permutationMap, getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp @@ -83,26 +83,6 @@ newVectorShape[pos.value()] = originalShape[pos.index()]; } - // Transpose mask operand. - Value newMask; - if (op.getMask()) { - // Remove unused dims from the permutation map. E.g.: - // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) - // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) - auto comp = compressUnusedDims(map); - // Get positions of remaining result dims. - // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) - // maskTransposeIndices = [ 2, 1, 0] - SmallVector maskTransposeIndices; - for (unsigned i = 0; i < comp.getNumResults(); ++i) { - if (auto expr = comp.getResult(i).dyn_cast()) - maskTransposeIndices.push_back(expr.getPosition()); - } - - newMask = rewriter.create(op.getLoc(), op.getMask(), - maskTransposeIndices); - } - // Transpose in_bounds attribute. ArrayAttr newInBoundsAttr = op.getInBounds() ? 
transposeInBoundsAttr( @@ -114,7 +94,8 @@ VectorType::get(newVectorShape, op.getVectorType().getElementType()); Value newRead = rewriter.create( op.getLoc(), newReadType, op.getSource(), op.getIndices(), - AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr); + AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), + newInBoundsAttr); // Transpose result of transfer_read. SmallVector transposePerm(permutation.begin(), permutation.end()); @@ -168,11 +149,6 @@ return expr.dyn_cast().getPosition(); }); - // Transpose mask operand. - Value newMask = op.getMask() ? rewriter.create( - op.getLoc(), op.getMask(), indices) - : Value(); - // Transpose in_bounds attribute. ArrayAttr newInBoundsAttr = op.getInBounds() ? transposeInBoundsAttr( @@ -186,7 +162,7 @@ map.getNumDims(), map.getNumResults(), rewriter.getContext()); rewriter.replaceOpWithNewOp( op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), - newMask, newInBoundsAttr); + op.getMask(), newInBoundsAttr); return success(); } diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp --- a/mlir/lib/Interfaces/VectorInterfaces.cpp +++ b/mlir/lib/Interfaces/VectorInterfaces.cpp @@ -10,19 +10,6 @@ using namespace mlir; -VectorType mlir::vector::detail::transferMaskType(VectorType vecType, - AffineMap map) { - auto i1Type = IntegerType::get(map.getContext(), 1); - SmallVector shape; - for (int64_t i = 0; i < vecType.getRank(); ++i) { - // Only result dims have a corresponding dim in the mask. 
- if (map.getResult(i).template isa()) { - shape.push_back(vecType.getDimSize(i)); - } - } - return VectorType::get(shape, i1Type); -} - //===----------------------------------------------------------------------===// // VectorUnroll Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir @@ -5,10 +5,9 @@ // CHECK-LABEL: func @transfer_read_2d_mask_transposed( // CHECK-DAG: %[[PADDING:.*]] = arith.constant dense<-4.200000e+01> : vector<9xf32> -// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<9x4xi1> +// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<4x9xi1> // CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref> -// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1> -// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref> +// CHECK: memref.store %[[MASK]], %[[MASK_MEM]][] : memref> // CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref> to memref<4xvector<9xi1>> // CHECK: scf.for {{.*}} { // CHECK: scf.if {{.*}} { @@ -25,11 +24,10 @@ func.func @transfer_read_2d_mask_transposed( %A : memref, %base1: index, %base2: index) -> (vector<9x4xf32>) { %fm42 = arith.constant -42.0: f32 - %mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0], - [1, 1, 1, 1], [0, 1, 1, 0], - [1, 1, 1, 1], [1, 1, 1, 1], - [1, 1, 1, 1], [0, 0, 0, 0], - [1, 1, 1, 1]]> : vector<9x4xi1> + %mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1], + [0, 0, 1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 0, 1], + [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1> %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask {permutation_map = 
affine_map<(d0, d1) -> (d1, d0)>} : memref, vector<9x4xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -49,7 +49,7 @@ %v0 = vector.splat %c0 : vector<4x3xi32> %vi0 = vector.splat %i0 : vector<4x3xindex> %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1> - %m2 = vector.splat %i1 : vector<5x4xi1> + %m2 = vector.splat %i1 : vector<4x5xi1> // // CHECK: vector.transfer_read %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref, vector<128xf32> diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -282,19 +282,19 @@ %c0 = arith.constant 0 : index // CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1> - %mask0 = vector.splat %m : vector<7x14xi1> + %mask0 = vector.splat %m : vector<14x7xi1> %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> // CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1> - %mask1 = vector.splat %m : vector<14x16xi1> + %mask1 = vector.splat %m : vector<16x14xi1> %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : 
vector<16x14x7x8xf32> to vector<7x14x8x16xf32> // CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1> - %mask2 = vector.splat %m : vector<7x14xi1> + %mask2 = vector.splat %m : vector<14x7xi1> %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> @@ -338,7 +338,7 @@ %c0 = arith.constant 0 : index // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1> - %mask0 = vector.splat %m : vector<7x14x8x16xi1> + %mask0 = vector.splat %m : vector<8x14x16x7xi1> %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32> // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor