diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -863,17 +863,41 @@
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
+class Vector_TransferOpBase {
+  code extraTransferDeclaration = [{
+    static StringRef getMaskedAttrName() { return "masked"; }
+    static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    bool isMaskedDim(unsigned dim) {
+      return !masked() ||
+             masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
+    }
+    /// Build the default minor identity map suitable for a vector transfer.
+    /// This also handles the case memref<... x vector<...>> -> vector<...> in
+    /// which the rank of the identity map must take the vector element type
+    /// into account.
+    static AffineMap getTransferMinorIdentityMap(
+        MemRefType memRefType, VectorType vectorType);
+  }];
+}
+
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read">,
     Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
-               AffineMapAttr:$permutation_map, AnyType:$padding)>,
+               AffineMapAttr:$permutation_map, AnyType:$padding,
+               OptionalAttr<BoolArrayAttr>:$masked)>,
     Results<(outs AnyVector:$vector)> {
 
   let summary = "Reads a supervector from memory into an SSA vector value.";
 
   let description = [{
-    The `vector.transfer_read` op performs a blocking read from a slice within
-    a [MemRef](../LangRef.md#memref-type) supplied as its first operand
+    The `vector.transfer_read` op performs a read from a slice within a
+    [MemRef](../LangRef.md#memref-type) supplied as its first operand
     into a [vector](../LangRef.md#vector-type) of the same base elemental type.
 
     A memref operand with vector element type, must have its vector element
@@ -881,18 +905,31 @@
     memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
 
     The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
-    [attribute](../LangRef.md#attributes) is an
+    supplied as the operands `2 .. 1 + rank(memref)`.
+
+    The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
-    slice to match the vector shape. The size of the slice is specified by the
-    size of the vector, given as the return type. An `ssa-value` of the same
-    elemental type as the MemRef is provided as the last operand to specify
-    padding in the case of out-of-bounds accesses. This operation is called
-    'read' by opposition to 'load' because the super-vector granularity is
-    generally not representable with a single hardware register.
-    A `vector.transfer_read` is thus a mid-level
-    abstraction that supports super-vectorization with non-effecting padding for
-    full-tile-only code.
+    slice to match the vector shape. The permutation map may be implicit and
+    omitted from parsing and printing if it is the canonical minor identity map
+    (i.e. if it does not permute or broadcast any dimension).
+
+    The size of the slice is specified by the size of the vector, given as the
+    return type.
+
+    An `ssa-value` of the same elemental type as the MemRef is provided as the
+    last operand to specify padding in the case of out-of-bounds accesses.
+
+    An optional boolean array attribute is provided to specify which dimensions
+    of the transfer need masking. When a dimension is specified as not requiring
+    masking, the `vector.transfer_read` may be lowered to simple loads. The
+    absence of this `masked` attribute signifies that all dimensions of the
+    transfer need to be masked.
+
+    This operation is called 'read' by opposition to 'load' because the
+    super-vector granularity is generally not representable with a single
+    hardware register. A `vector.transfer_read` is thus a mid-level abstraction
+    that supports super-vectorization with non-effecting padding for
+    full-tile-only code.
 
     More precisely, let's dive deeper into the permutation_map for the following
     MLIR:
@@ -995,10 +1032,15 @@
   }];
 
   let builders = [
-    // Builder that sets permutation map and padding to 'getMinorIdentityMap'
-    // and zero, respectively, by default.
+    // Builder that sets padding to zero.
+    OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
+              "Value memref, ValueRange indices, AffineMap permutationMap, "
+              "ArrayRef<bool> maybeMasked = {}">,
+    // Builder that sets permutation map (resp. padding) to
+    // 'getMinorIdentityMap' (resp. zero).
     OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
-              "Value memref, ValueRange indices">
+              "Value memref, ValueRange indices, "
+              "ArrayRef<bool> maybeMasked = {}">
   ];
 
   let extraClassDeclaration = [{
@@ -1015,12 +1057,13 @@
   Vector_Op<"transfer_write">,
     Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
                Variadic<Index>:$indices,
-               AffineMapAttr:$permutation_map)> {
+               AffineMapAttr:$permutation_map,
+               OptionalAttr<BoolArrayAttr>:$masked)> {
 
   let summary = "The vector.transfer_write op writes a supervector to memory.";
 
   let description = [{
-    The `vector.transfer_write` performs a blocking write from a
+    The `vector.transfer_write` op performs a write from a
     [vector](../LangRef.md#vector-type), supplied as its first operand, into a
     slice within a [MemRef](../LangRef.md#memref-type) of the same base
     elemental type, supplied as its second operand.
@@ -1031,12 +1074,24 @@
     The slice is further defined by a full-rank index within the MemRef,
     supplied as the operands `3 .. 2 + rank(memref)`.
+
     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
-    slice to match the vector shape. The size of the slice is specified by the
-    size of the vector. This operation is called 'write' by opposition to
-    'store' because the super-vector granularity is generally not representable
-    with a single hardware register. A `vector.transfer_write` is thus a
+    slice to match the vector shape. The permutation map may be implicit and
+    omitted from parsing and printing if it is the canonical minor identity map
+    (i.e. if it does not permute or broadcast any dimension).
+
+    The size of the slice is specified by the size of the vector.
+
+    An optional boolean array attribute is provided to specify which dimensions
+    of the transfer need masking. When a dimension is specified as not requiring
+    masking, the `vector.transfer_write` may be lowered to simple stores. The
+    absence of this `masked` attribute signifies that all dimensions of the
+    transfer need to be masked.
+
+    This operation is called 'write' by opposition to 'store' because the
+    super-vector granularity is generally not representable with a single
+    hardware register. A `vector.transfer_write` is thus a
     mid-level abstraction that supports super-vectorization with non-effecting
     padding for full-tile-only code. It is the responsibility of
     `vector.transfer_write`'s implementation to ensure the memory writes are
@@ -1069,7 +1124,10 @@
     // Builder that sets permutation map and padding to 'getMinorIdentityMap'
     // by default.
     OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
-              "Value memref, ValueRange indices">
+              "Value memref, ValueRange indices, "
+              "ArrayRef<bool> maybeMasked = {}">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
+              "Value memref, ValueRange indices, AffineMap permutationMap">,
   ];
 
   let extraClassDeclaration = [{
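A minimal illustration of the syntax described in the op documentation above, assuming %A, %i, %j and %pad are defined elsewhere (the minor identity permutation_map is elided from the custom form, and `masked` is printed only when at least one dimension is unmasked):

  %v = vector.transfer_read %A[%i, %j], %pad {masked = [false, true]}
      : memref<?x?xf32>, vector<4x8xf32>
  vector.transfer_write %v, %A[%i, %j] {masked = [false, true]}
      : vector<4x8xf32>, memref<?x?xf32>
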
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -746,12 +746,6 @@
   }
 };
 
-template <typename ConcreteOp>
-LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter,
-                                LLVMTypeConverter &typeConverter, Location loc,
-                                Operation *op, ArrayRef<Value> operands,
-                                Value dataPtr, Value mask);
-
 LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
                                       Type type, LLVM::LLVMType &llvmType,
                                       unsigned &align) {
@@ -765,12 +759,25 @@
   return success();
 }
 
-template <>
-LogicalResult replaceTransferOp<TransferReadOp>(
-    ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
-    Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
-    Value mask) {
-  auto xferOp = cast<TransferReadOp>(op);
+LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferReadOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
+  LLVM::LLVMType vecTy;
+  unsigned align;
+  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+                                     vecTy, align)))
+    return failure();
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+  return success();
+}
+
+LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                                          LLVMTypeConverter &typeConverter,
+                                          Location loc, TransferReadOp xferOp,
+                                          ArrayRef<Value> operands,
+                                          Value dataPtr, Value mask) {
   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
 
   VectorType fillType = xferOp.getVectorType();
   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
@@ -783,19 +790,32 @@
     return failure();
 
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      op, vecTy, dataPtr, mask, ValueRange{fill},
+      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
       rewriter.getI32IntegerAttr(align));
   return success();
 }
 
-template <>
-LogicalResult replaceTransferOp<TransferWriteOp>(
-    ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
-    Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
-    Value mask) {
+LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferWriteOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
   auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  LLVM::LLVMType vecTy;
+  unsigned align;
+  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+                                     vecTy, align)))
+    return failure();
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+  return success();
+}
 
-  auto xferOp = cast<TransferWriteOp>(op);
+LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                                          LLVMTypeConverter &typeConverter,
+                                          Location loc, TransferWriteOp xferOp,
+                                          ArrayRef<Value> operands,
+                                          Value dataPtr, Value mask) {
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
   LLVM::LLVMType vecTy;
   unsigned align;
   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
@@ -803,7 +823,8 @@
     return failure();
 
   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
+      xferOp, adaptor.vector(), dataPtr, mask,
+      rewriter.getI32IntegerAttr(align));
   return success();
 }
 
@@ -877,6 +898,10 @@
       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
           loc, vecTy.getPointerTo(), dataPtr);
 
+    if (!xferOp.isMaskedDim(0))
+      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
+                                              xferOp, operands, vectorDataPtr);
+
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
     unsigned vecWidth = vecTy.getVectorNumElements();
     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
@@ -910,8 +935,8 @@
                                          mask);
 
     // 5. Rewrite as a masked read / write.
-    return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op,
-                                         operands, vectorDataPtr, mask);
+    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
+                                       operands, vectorDataPtr, mask);
   }
 };
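A sketch of what the new `isMaskedDim(0)` early exit above buys for a 1-D unmasked read; this mirrors the transfer_read_1d_not_masked test further down, with illustrative SSA names:

  %f = vector.transfer_read %A[%base], %f7 {masked = [false]} : memref<?xf32>, vector<17xf32>
  // lowers to a plain load instead of llvm.intr.masked.load:
  %vecPtr = llvm.bitcast %gep : !llvm<"float*"> to !llvm<"<17 x float>*">
  %f_llvm = llvm.load %vecPtr : !llvm<"<17 x float>*">
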
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,25 +157,34 @@
                            ValueRange majorIvs, ValueRange majorOffsets,
                            MemRefBoundsCapture &memrefBounds,
                            LambdaThen thenBlockBuilder,
                            LambdaElse elseBlockBuilder) {
-  Value inBounds = std_constant_int(/*value=*/1, /*width=*/1);
+  Value inBounds;
   SmallVector<Value, 8> majorIvsPlusOffsets;
   majorIvsPlusOffsets.reserve(majorIvs.size());
+  unsigned idx = 0;
   for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
     Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
     using namespace mlir::edsc::op;
     majorIvsPlusOffsets.push_back(iv + off);
-    Value inBounds2 = majorIvsPlusOffsets.back() < ub;
-    inBounds = inBounds && inBounds2;
+    if (xferOp.isMaskedDim(leadingRank + idx)) {
+      Value inBounds2 = majorIvsPlusOffsets.back() < ub;
+      inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
+    }
+    ++idx;
   }
-  auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
-      ScopedContext::getLocation(), TypeRange{}, inBounds,
-      /*withElseRegion=*/std::is_same());
-  BlockBuilder(&ifOp.thenRegion().front(),
-               Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
-  if (std::is_same())
-    BlockBuilder(&ifOp.elseRegion().front(),
-                 Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+  if (inBounds) {
+    auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+        ScopedContext::getLocation(), TypeRange{}, inBounds,
+        /*withElseRegion=*/std::is_same());
+    BlockBuilder(&ifOp.thenRegion().front(),
+                 Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
+    if (std::is_same())
+      BlockBuilder(&ifOp.elseRegion().front(),
+                   Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+  } else {
+    // Just build the body of the then block right here.
+    thenBlockBuilder(majorIvsPlusOffsets);
+  }
 }
 
 template <>
@@ -196,9 +205,17 @@
     indexing.append(leadingOffsets.begin(), leadingOffsets.end());
     indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
     indexing.append(minorOffsets.begin(), minorOffsets.end());
-    auto loaded1D =
-        vector_transfer_read(minorVectorType, memref, indexing,
-                             AffineMapAttr::get(map), xferOp.padding());
+    Value memref = xferOp.memref();
+    auto map = TransferReadOp::getTransferMinorIdentityMap(
+        xferOp.getMemRefType(), minorVectorType);
+    ArrayAttr masked;
+    if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+      OpBuilder &b = ScopedContext::getBuilderRef();
+      masked = b.getBoolArrayAttr({true});
+    }
+    auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
+                                         AffineMapAttr::get(map),
+                                         xferOp.padding(), masked);
     // Store the 1-D vector.
     std_store(loaded1D, alloc, majorIvs);
   };
@@ -229,17 +246,22 @@
                         ValueRange majorOffsets, ValueRange minorOffsets,
                         MemRefBoundsCapture &memrefBounds) {
   auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
-    // Lower to 1-D vector_transfer_write and let recursion handle it.
-    Value loaded1D = std_load(alloc, majorIvs);
-    auto map = AffineMap::getMinorIdentityMap(
-        xferOp.getMemRefType().getRank(), minorRank, xferOp.getContext());
     SmallVector<Value, 8> indexing;
     indexing.reserve(leadingRank + majorRank + minorRank);
     indexing.append(leadingOffsets.begin(), leadingOffsets.end());
     indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
     indexing.append(minorOffsets.begin(), minorOffsets.end());
+    // Lower to 1-D vector_transfer_write and let recursion handle it.
+    Value loaded1D = std_load(alloc, majorIvs);
+    auto map = TransferWriteOp::getTransferMinorIdentityMap(
+        xferOp.getMemRefType(), minorVectorType);
+    ArrayAttr masked;
+    if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+      OpBuilder &b = ScopedContext::getBuilderRef();
+      masked = b.getBoolArrayAttr({true});
+    }
     vector_transfer_write(loaded1D, xferOp.memref(), indexing,
-                          AffineMapAttr::get(map));
+                          AffineMapAttr::get(map), masked);
   };
   // Don't write anything when out of bounds.
   auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1020,8 +1020,7 @@
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
   LLVM_DEBUG(permutationMap.print(dbgs()));
   auto transfer = b.create<vector::TransferWriteOp>(
-      opInst->getLoc(), vectorValue, memRef, indices,
-      AffineMapAttr::get(permutationMap));
+      opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
   auto *res = transfer.getOperation();
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
   // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1202,6 +1202,24 @@
 //===----------------------------------------------------------------------===//
 // TransferReadOp
 //===----------------------------------------------------------------------===//
+
+namespace impl {
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+static AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+                                             VectorType vectorType) {
+  int64_t elementVectorRank = 0;
+  VectorType elementVectorType =
+      memRefType.getElementType().dyn_cast<VectorType>();
+  if (elementVectorType)
+    elementVectorRank += elementVectorType.getRank();
+  return AffineMap::getMinorIdentityMap(
+      memRefType.getRank(), vectorType.getRank() - elementVectorRank,
+      memRefType.getContext());
+}
+} // namespace impl
+
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
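To make the arithmetic in impl::getTransferMinorIdentityMap concrete, using the memref-of-vector types that appear in the tests below: for a transfer between memref<?x?xvector<4x3xf32>> and vector<1x1x4x3xf32>, the element vector contributes rank 2, so the minor identity map is built with 2 inputs and 4 - 2 = 2 results, i.e. affine_map<(d0, d1) -> (d0, d1)>, the same map spelled out explicitly for that case in ops.mlir.
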
@@ -1233,7 +1251,8 @@
 static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
                                       VectorType vectorType,
-                                      AffineMap permutationMap) {
+                                      AffineMap permutationMap,
+                                      ArrayAttr optionalMasked) {
   auto memrefElementType = memrefType.getElementType();
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
     // Memref has vector element type.
@@ -1281,22 +1300,69 @@
   if (permutationMap.getNumInputs() != memrefType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the memref type");
+
+  if (optionalMasked) {
+    if (permutationMap.getNumResults() !=
+        static_cast<int64_t>(optionalMasked.size()))
+      return op->emitOpError("expects the optional masked attr of same rank as "
+                             "permutation_map results: ")
+             << AffineMapAttr::get(permutationMap);
+  }
+
   return success();
 }
 
-/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
-/// zero, respectively, by default.
+/// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value memref,
-                           ValueRange indices) {
-  auto permMap = AffineMap::getMinorIdentityMap(
-      memref.getType().cast<MemRefType>().getRank(), vector.getRank(),
-      builder.getContext());
+                           VectorType vector, Value memref, ValueRange indices,
+                           AffineMap permutationMap,
+                           ArrayRef<bool> maybeMasked) {
   Type elemType = vector.cast<VectorType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
+  if (maybeMasked.empty())
+    return build(builder, result, vector, memref, indices, permutationMap,
+                 padding, ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vector, memref, indices, permutationMap, padding,
+        maskedArrayAttr);
 }
-
-  build(builder, result, vector, memref, indices, permMap, padding);
+
+/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
+/// (resp. zero).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value memref,
+                           ValueRange indices, ArrayRef<bool> maybeMasked) {
+  auto permMap = getTransferMinorIdentityMap(
+      memref.getType().cast<MemRefType>(), vectorType);
+  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
+}
+
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+AffineMap TransferReadOp::getTransferMinorIdentityMap(MemRefType memRefType,
+                                                      VectorType vectorType) {
+  return ::impl::getTransferMinorIdentityMap(memRefType, vectorType);
+}
+
+template <typename TransferOp>
+void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
+  SmallVector<StringRef, 2> elidedAttrs;
+  if (AffineMap::isMinorIdentity(op.permutation_map()))
+    elidedAttrs.push_back(op.getPermutationMapAttrName());
+  bool elideMasked = true;
+  if (auto maybeMasked = op.masked()) {
+    for (auto attr : *maybeMasked) {
+      if (!attr.template cast<BoolAttr>().getValue()) {
+        elideMasked = false;
+        break;
+      }
+    }
+  }
+  if (elideMasked)
+    elidedAttrs.push_back(op.getMaskedAttrName());
+  p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+}
 
 static void print(OpAsmPrinter &p, TransferReadOp op) {
@@ -1347,7 +1413,8 @@
   return op.emitOpError("requires ") << memrefType.getRank() << " indices";
 
   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
-                              permutationMap)))
+                              permutationMap,
+                              op.masked() ? *op.masked() : ArrayAttr())))
     return failure();
 
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
@@ -1379,12 +1446,23 @@
 /// Builder that sets permutation map and padding to 'getMinorIdentityMap' by
 /// default.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices) {
+                            Value vector, Value memref, ValueRange indices,
+                            ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
-  auto permMap = AffineMap::getMinorIdentityMap(
-      memref.getType().cast<MemRefType>().getRank(), vectorType.getRank(),
-      builder.getContext());
-  build(builder, result, vector, memref, indices, permMap);
+  auto permMap = getTransferMinorIdentityMap(
+      memref.getType().cast<MemRefType>(), vectorType);
+  if (maybeMasked.empty())
+    return build(builder, result, vector, memref, indices, permMap,
+                 ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value memref, ValueRange indices,
+                            AffineMap permutationMap) {
+  build(builder, result, vector, memref, indices, permutationMap,
+        /*maybeMasked=*/ArrayAttr());
 }
 
 static LogicalResult verify(TransferWriteOp op) {
@@ -1397,7 +1475,8 @@
   return op.emitOpError("requires ") << memrefType.getRank() << " indices";
 
   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
-                              permutationMap)))
+                              permutationMap,
+                              op.masked() ? *op.masked() : ArrayAttr())))
     return failure();
 
   return verifyPermutationMap(permutationMap,
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -564,9 +564,12 @@
     // Get VectorType for slice 'i'.
     auto sliceVectorType = resultTupleType.getType(index);
     // Create split TransferReadOp for 'sliceUser'.
+    // `masked` attribute propagates conservatively: if the coarse op didn't
+    // need masking, the fine op doesn't either.
     vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
         loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
-        xferReadOp.permutation_map(), xferReadOp.padding());
+        xferReadOp.permutation_map(), xferReadOp.padding(),
+        xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
   };
   generateTransferOpSlices(memrefElementType, sourceVectorType,
                            resultTupleType, sizes, strides, indices, rewriter,
@@ -620,9 +623,12 @@
                                  xferWriteOp.indices().end());
   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
     // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
+    // `masked` attribute propagates conservatively: if the coarse op didn't
+    // need masking, the fine op doesn't either.
    rewriter.create<vector::TransferWriteOp>(
         loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
-        xferWriteOp.permutation_map());
+        xferWriteOp.permutation_map(),
+        xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
   };
   generateTransferOpSlices(memrefElementType, resultVectorType,
                            sourceTupleType, sizes, strides, indices, rewriter,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -918,6 +918,24 @@
 // CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
 // CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">
 
+func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7 {masked = [false]} :
+    memref<?xf32>, vector<17xf32>
+  return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_not_masked
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
+
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>
diff --git a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
--- a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
@@ -220,14 +220,12 @@
   // CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
   %f7 = constant 7.0: f32
 
-  // CHECK-DAG: %[[cond0:.*]] = constant 1 : i1
   // CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
   // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
   // CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
   // CHECK: affine.for %[[I:.*]] = 0 to 17 {
   // CHECK:   %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
-  // CHECK:   %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
-  // CHECK:   %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
+  // CHECK:   %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
   // CHECK:   scf.if %[[cond1]] {
   // CHECK:     %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] {permutation_map = #[[MAP1]]} : memref<?x?xf32>, vector<15xf32>
   // CHECK:     store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
@@ -253,7 +251,6 @@
 // CHECK-SAME:   %[[base:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:   %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
 func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
-  // CHECK: %[[cond0:.*]] = constant 1 : i1
   // CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
   // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
   // CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
@@ -261,8 +258,7 @@
   // CHECK: affine.for %[[I:.*]] = 0 to 17 {
   // CHECK:   %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
   // CHECK:   %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
-  // CHECK:   %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
-  // CHECK:   scf.if %[[cond1]] {
+  // CHECK:   scf.if %[[cmp]] {
   // CHECK:     %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
   // CHECK:     vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] {permutation_map = #[[MAP1]]} : vector<15xf32>, memref<?x?xf32>
   // CHECK:   }
@@ -271,3 +267,26 @@
     vector<17x15xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: transfer_write_progressive_not_masked(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:   %[[base:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:   %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
+func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
+  // CHECK-NOT:    scf.if
+  // CHECK-NEXT:  %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
+  // CHECK-NEXT:  %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
+  // CHECK-NEXT:  store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
+  // CHECK-NEXT:  affine.for %[[I:.*]] = 0 to 17 {
+  // CHECK-NEXT:    %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
+  // CHECK-NEXT:    %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
+  // CHECK-NEXT:    vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+  vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
+    vector<17x15xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -337,6 +337,16 @@
 
 // -----
 
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
+  %c3 = constant 3 : index
+  %f0 = constant 0.0 : f32
+  %vf0 = splat %f0 : vector<2x3xf32>
+  // expected-error@+1 {{ expects the optional masked attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {masked = [false], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
+}
+
+// -----
+
 func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
   %cst = constant dense<3.0> : vector<128 x f32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -22,6 +22,8 @@
   %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : memref<?x?xf32>, vector<128xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -29,6 +31,8 @@
   vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   return
 }
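
Taken together with the elision rules in printTransferAttrs, the two forms below denote the same operation (illustrative, assuming %A, %i and %pad are defined elsewhere): an all-true masked attribute and a minor identity permutation_map are both dropped when printing:

  %v = vector.transfer_read %A[%i, %i], %pad
      {masked = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<?x?xf32>, vector<4x8xf32>
  %v = vector.transfer_read %A[%i, %i], %pad : memref<?x?xf32>, vector<4x8xf32>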