diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -865,7 +865,12 @@
 def Vector_TransferOpUtils {
   code extraTransferDeclaration = [{
+    static StringRef getMaskedAttrName() { return "masked"; }
     static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+    bool isMaskedDim(unsigned dim) {
+      return !masked() ||
+             masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
+    }
     MemRefType getMemRefType() {
       return memref().getType().cast<MemRefType>();
     }
@@ -878,14 +883,15 @@
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read">,
   Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
-             AffineMapAttr:$permutation_map, AnyType:$padding)>,
+             AffineMapAttr:$permutation_map, AnyType:$padding,
+             OptionalAttr<BoolArrayAttr>:$masked)>,
   Results<(outs AnyVector:$vector)> {

   let summary = "Reads a supervector from memory into an SSA vector value.";

   let description = [{
-    The `vector.transfer_read` op performs a blocking read from a slice within
-    a [MemRef](../LangRef.md#memref-type) supplied as its first operand
+    The `vector.transfer_read` op performs a read from a slice within a
+    [MemRef](../LangRef.md#memref-type) supplied as its first operand
     into a [vector](../LangRef.md#vector-type) of the same base elemental type.

     A memref operand with vector element type, must have its vector element
@@ -893,8 +899,9 @@
     memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).

     The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
-    [attribute](../LangRef.md#attributes) is an
+    supplied as the operands `2 .. 1 + rank(memref)`.
+
+    The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on
     the slice to match the vector shape. The permutation map may be implicit and
     omitted from parsing and printing if it is the canonical minor identity map
@@ -906,6 +913,12 @@
     An `ssa-value` of the same elemental type as the MemRef is provided as the
     last operand to specify padding in the case of out-of-bounds accesses.

+    An optional boolean array attribute is provided to specify which dimensions
+    of the transfer need masking. When a dimension is specified as not requiring
+    masking, the `vector.transfer_read` may be lowered to simple loads. The
+    absence of this `masked` attribute signifies that all dimensions of the
+    transfer need to be masked.
+
     This operation is called 'read' by opposition to 'load' because the
     super-vector granularity is generally not representable with a single
     hardware register. A `vector.transfer_read` is thus a mid-level abstraction
@@ -1015,11 +1028,13 @@
   let builders = [
     // Builder that sets padding to zero.
     OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
-              "Value memref, ValueRange indices, AffineMap permutationMap">,
+              "Value memref, ValueRange indices, AffineMap permutationMap, "
+              "ArrayRef<bool> maybeMasked = {}">,
     // Builder that sets permutation map (resp. padding) to
     // 'getMinorIdentityMap' (resp. zero).
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, " - "Value memref, ValueRange indices"> + "Value memref, ValueRange indices, " + "ArrayRef maybeMasked = {}"> ]; let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration # @@ -1039,12 +1054,13 @@ Vector_Op<"transfer_write">, Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, Variadic:$indices, - AffineMapAttr:$permutation_map)> { + AffineMapAttr:$permutation_map, + OptionalAttr:$masked)> { let summary = "The vector.transfer_write op writes a supervector to memory."; let description = [{ - The `vector.transfer_write` performs a blocking write from a + The `vector.transfer_write` op performs a write from a [vector](../LangRef.md#vector-type), supplied as its first operand, into a slice within a [MemRef](../LangRef.md#memref-type) of the same base elemental type, supplied as its second operand. @@ -1055,6 +1071,7 @@ The slice is further defined by a full-rank index within the MemRef, supplied as the operands `3 .. 2 + rank(memref)`. + The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The permutation map may be implicit and @@ -1063,6 +1080,12 @@ The size of the slice is specified by the size of the vector. + An optional boolean array attribute is provided to specify which dimensions + of the transfer need masking. When a dimension is specified as not requiring + masking, the `vector.transfer_write` may be lowered to simple stores. The + absence of this `mask` attribute signifies that all dimensions of the + transfer need to be masked. + This operation is called 'write' by opposition to 'store' because the super-vector granularity is generally not representable with a single hardware register. A `vector.transfer_write` is thus a @@ -1097,7 +1120,10 @@ let builders = [ // Builder that sets permutation map to 'getMinorIdentityMap'. 
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, " - "Value memref, ValueRange indices"> + "Value memref, ValueRange indices, " + "ArrayRef maybeMasked = {}">, + OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, " + "Value memref, ValueRange indices, AffineMap permutationMap">, ]; let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration # diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -746,12 +746,6 @@ } }; -template -LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - Operation *op, ArrayRef operands, - Value dataPtr, Value mask); - LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter, Type type, LLVM::LLVMType &llvmType, unsigned &align) { @@ -765,12 +759,25 @@ return success(); } -template <> -LogicalResult replaceTransferOp( - ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, - Location loc, Operation *op, ArrayRef operands, Value dataPtr, - Value mask) { - auto xferOp = cast(op); +LogicalResult +replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, Location loc, + TransferReadOp xferOp, + ArrayRef operands, Value dataPtr) { + LLVM::LLVMType vecTy; + unsigned align; + if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), + vecTy, align))) + return failure(); + rewriter.replaceOpWithNewOp(xferOp, dataPtr); + return success(); +} + +LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, + Location loc, TransferReadOp xferOp, + ArrayRef operands, + Value dataPtr, Value mask) { auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; VectorType fillType = xferOp.getVectorType(); Value fill = rewriter.create(loc, fillType, xferOp.padding()); @@ -783,19 +790,32 @@ return failure(); rewriter.replaceOpWithNewOp( - op, vecTy, dataPtr, mask, ValueRange{fill}, + xferOp, vecTy, dataPtr, mask, ValueRange{fill}, rewriter.getI32IntegerAttr(align)); return success(); } -template <> -LogicalResult replaceTransferOp( - ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, - Location loc, Operation *op, ArrayRef operands, Value dataPtr, - Value mask) { +LogicalResult +replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, Location loc, + TransferWriteOp xferOp, + ArrayRef operands, Value dataPtr) { auto adaptor = TransferWriteOpOperandAdaptor(operands); + LLVM::LLVMType vecTy; + unsigned align; + if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), + vecTy, align))) + return failure(); + rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr); + return success(); +} - auto xferOp = cast(op); +LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, + Location loc, TransferWriteOp xferOp, + ArrayRef operands, + Value dataPtr, Value mask) { + auto adaptor = TransferWriteOpOperandAdaptor(operands); LLVM::LLVMType vecTy; unsigned align; if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), @@ -803,7 +823,8 @@ return failure(); rewriter.replaceOpWithNewOp( - op, adaptor.vector(), dataPtr, mask, 
-      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
+      xferOp, adaptor.vector(), dataPtr, mask,
+      rewriter.getI32IntegerAttr(align));
   return success();
 }
@@ -877,6 +898,10 @@
       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
           loc, vecTy.getPointerTo(), dataPtr);

+    if (!xferOp.isMaskedDim(0))
+      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
+                                              xferOp, operands, vectorDataPtr);
+
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
     unsigned vecWidth = vecTy.getVectorNumElements();
     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
@@ -910,8 +935,8 @@
                                                  mask);

     // 5. Rewrite as a masked read / write.
-    return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op,
-                                         operands, vectorDataPtr, mask);
+    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
+                                       operands, vectorDataPtr, mask);
   }
 };
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -155,25 +155,34 @@
     ValueRange majorIvs, ValueRange majorOffsets,
     MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder,
     LambdaElse elseBlockBuilder) {
-  Value inBounds = std_constant_int(/*value=*/1, /*width=*/1);
+  Value inBounds;
  SmallVector<Value, 8> majorIvsPlusOffsets;
   majorIvsPlusOffsets.reserve(majorIvs.size());
+  unsigned idx = 0;
   for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
     Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
     using namespace mlir::edsc::op;
     majorIvsPlusOffsets.push_back(iv + off);
-    Value inBounds2 = majorIvsPlusOffsets.back() < ub;
-    inBounds = inBounds && inBounds2;
+    if (xferOp.isMaskedDim(leadingRank + idx)) {
+      Value inBounds2 = majorIvsPlusOffsets.back() < ub;
+      inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
+    }
+    ++idx;
   }
-  auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
-      ScopedContext::getLocation(), TypeRange{}, inBounds,
-      /*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
-  BlockBuilder(&ifOp.thenRegion().front(),
-               Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
-  if (std::is_same<ConcreteOp, TransferReadOp>())
-    BlockBuilder(&ifOp.elseRegion().front(),
-                 Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+  if (inBounds) {
+    auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+        ScopedContext::getLocation(), TypeRange{}, inBounds,
+        /*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
+    BlockBuilder(&ifOp.thenRegion().front(),
+                 Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
+    if (std::is_same<ConcreteOp, TransferReadOp>())
+      BlockBuilder(&ifOp.elseRegion().front(),
+                   Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
+  } else {
+    // Just build the body of the then block right here.
+    thenBlockBuilder(majorIvsPlusOffsets);
+  }
 }

 template <> LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
@@ -189,13 +198,18 @@
       indexing.append(leadingOffsets.begin(), leadingOffsets.end());
       indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
       indexing.append(minorOffsets.begin(), minorOffsets.end());
-      // Lower to 1-D vector_transfer_read and let recursion handle it.
+      Value memref = xferOp.memref();
       auto map = TransferReadOp::getTransferMinorIdentityMap(
           xferOp.getMemRefType(), minorVectorType);
-      auto loaded1D =
-          vector_transfer_read(minorVectorType, memref, indexing,
-                               AffineMapAttr::get(map), xferOp.padding());
+      ArrayAttr masked;
+      if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+        OpBuilder &b = ScopedContext::getBuilderRef();
+        masked = b.getBoolArrayAttr({true});
+      }
+      auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
+                                           AffineMapAttr::get(map),
+                                           xferOp.padding(), masked);
       // Store the 1-D vector.
       std_store(loaded1D, alloc, majorIvs);
     };
@@ -225,7 +239,6 @@
       ValueRange majorOffsets, ValueRange minorOffsets,
       MemRefBoundsCapture &memrefBounds) {
     auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
-      // Lower to 1-D vector_transfer_write and let recursion handle it.
      SmallVector<Value, 8> indexing;
       indexing.reserve(leadingRank + majorRank + minorRank);
       indexing.append(leadingOffsets.begin(), leadingOffsets.end());
@@ -235,8 +248,13 @@
       Value loaded1D = std_load(alloc, majorIvs);
       auto map = TransferWriteOp::getTransferMinorIdentityMap(
           xferOp.getMemRefType(), minorVectorType);
+      ArrayAttr masked;
+      if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
+        OpBuilder &b = ScopedContext::getBuilderRef();
+        masked = b.getBoolArrayAttr({true});
+      }
       vector_transfer_write(loaded1D, xferOp.memref(), indexing,
-                            AffineMapAttr::get(map));
+                            AffineMapAttr::get(map), masked);
     };
     // Don't write anything when out of bounds.
     auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1017,8 +1017,7 @@
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
   LLVM_DEBUG(permutationMap.print(dbgs()));
   auto transfer = b.create<vector::TransferWriteOp>(
-      opInst->getLoc(), vectorValue, memRef, indices,
-      AffineMapAttr::get(permutationMap));
+      opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
   auto *res = transfer.getOperation();
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
   // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1202,6 +1202,23 @@
 //===----------------------------------------------------------------------===//
 // TransferReadOp
 //===----------------------------------------------------------------------===//
+
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+AffineMap
+mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
+                                                VectorType vectorType) {
+  int64_t elementVectorRank = 0;
+  VectorType elementVectorType =
+      memRefType.getElementType().dyn_cast<VectorType>();
+  if (elementVectorType)
+    elementVectorRank += elementVectorType.getRank();
+  return AffineMap::getMinorIdentityMap(
+      memRefType.getRank(), vectorType.getRank() - elementVectorRank,
+      memRefType.getContext());
+}
+
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
@@ -1233,7 +1250,8 @@
 static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
                                       VectorType vectorType,
-                                      AffineMap permutationMap) {
+                                      AffineMap permutationMap,
+                                      ArrayAttr optionalMasked) {
   auto memrefElementType = memrefType.getElementType();
   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
     // Memref has vector element type.
@@ -1282,52 +1300,60 @@
     return op->emitOpError("requires a permutation_map with input dims of the "
                            "same rank as the memref type");

-  return success();
-}
+  if (optionalMasked) {
+    if (permutationMap.getNumResults() !=
+        static_cast<int64_t>(optionalMasked.size()))
+      return op->emitOpError("expects the optional masked attr of same rank as "
+                             "permutation_map results: ")
+             << AffineMapAttr::get(permutationMap);
+  }

-/// Build the default minor identity map suitable for a vector transfer. This
-/// also handles the case memref<... x vector<...>> -> vector<...> in which the
-/// rank of the identity map must take the vector element type into account.
-AffineMap
-mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
-                                                VectorType vectorType) {
-  int64_t elementVectorRank = 0;
-  VectorType elementVectorType =
-      memRefType.getElementType().dyn_cast<VectorType>();
-  if (elementVectorType)
-    elementVectorRank += elementVectorType.getRank();
-  return AffineMap::getMinorIdentityMap(
-      memRefType.getRank(), vectorType.getRank() - elementVectorRank,
-      memRefType.getContext());
+  return success();
 }

-/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
-/// zero, respectively, by default.
+/// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vector, Value memref, ValueRange indices,
-                           AffineMap permutationMap) {
+                           AffineMap permutationMap,
+                           ArrayRef<bool> maybeMasked) {
   Type elemType = vector.cast<VectorType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
-  build(builder, result, vector, memref, indices, permutationMap, padding);
+  if (maybeMasked.empty())
+    return build(builder, result, vector, memref, indices, permutationMap,
+                 padding, ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vector, memref, indices, permutationMap, padding,
+        maskedArrayAttr);
 }

 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
 /// (resp. zero).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vectorType, Value memref,
-                           ValueRange indices) {
-  build(builder, result, vectorType, memref, indices,
-        getTransferMinorIdentityMap(memref.getType().cast<MemRefType>(),
-                                    vectorType));
+                           ValueRange indices, ArrayRef<bool> maybeMasked) {
+  auto permMap = getTransferMinorIdentityMap(
+      memref.getType().cast<MemRefType>(), vectorType);
+  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
 }

 template <typename TransferOp>
 void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
-  SmallVector<StringRef, 1> elidedAttrs;
+  SmallVector<StringRef, 2> elidedAttrs;
   if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
                                   op.getMemRefType(), op.getVectorType()))
     elidedAttrs.push_back(op.getPermutationMapAttrName());
+  bool elideMasked = true;
+  if (auto maybeMasked = op.masked()) {
+    for (auto attr : *maybeMasked) {
+      if (!attr.template cast<BoolAttr>().getValue()) {
+        elideMasked = false;
+        break;
+      }
+    }
+  }
+  if (elideMasked)
+    elidedAttrs.push_back(op.getMaskedAttrName());
   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
 }
@@ -1388,7 +1414,8 @@
     return op.emitOpError("requires ") << memrefType.getRank() << " indices";

   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
-                              permutationMap)))
+                              permutationMap,
+                              op.masked() ? *op.masked() : ArrayAttr())))
     return failure();

   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
@@ -1419,11 +1446,24 @@
 /// Builder that sets permutation map to 'getMinorIdentityMap'.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices) {
+                            Value vector, Value memref, ValueRange indices,
+                            ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
       memref.getType().cast<MemRefType>(), vectorType);
-  build(builder, result, vector, memref, indices, permMap);
+  if (maybeMasked.empty())
+    return build(builder, result, vector, memref, indices, permMap,
+                 ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+}
+
+/// Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value memref, ValueRange indices,
+                            AffineMap permutationMap) {
+  build(builder, result, vector, memref, indices,
+        /*maybeMasked=*/ArrayRef<bool>{});
 }

 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
@@ -1477,7 +1517,8 @@
     return op.emitOpError("requires ") << memrefType.getRank() << " indices";

   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
-                              permutationMap)))
+                              permutationMap,
+                              op.masked() ? *op.masked() : ArrayAttr())))
     return failure();

   return verifyPermutationMap(permutationMap,
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -564,9 +564,12 @@
     // Get VectorType for slice 'i'.
     auto sliceVectorType = resultTupleType.getType(index);
     // Create split TransferReadOp for 'sliceUser'.
+    // `masked` attribute propagates conservatively: if the coarse op didn't
+    // need masking, the fine op doesn't either.
     vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
         loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
-        xferReadOp.permutation_map(), xferReadOp.padding());
+        xferReadOp.permutation_map(), xferReadOp.padding(),
+        xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
   };
   generateTransferOpSlices(memrefElementType, sourceVectorType,
                            resultTupleType, sizes, strides, indices, rewriter,
@@ -620,9 +623,12 @@
                      xferWriteOp.indices().end());
   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
     // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
+    // `masked` attribute propagates conservatively: if the coarse op didn't
+    // need masking, the fine op doesn't either.
     rewriter.create<vector::TransferWriteOp>(
         loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
-        xferWriteOp.permutation_map());
+        xferWriteOp.permutation_map(),
+        xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
   };
   generateTransferOpSlices(memrefElementType, resultVectorType,
                            sourceTupleType, sizes, strides, indices, rewriter,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -918,6 +918,24 @@
 // CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
 // CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">

+func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7 {masked = [false]} :
+    memref<?xf32>, vector<17xf32>
+  return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_not_masked
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
+
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>
   return %0 : vector<8xi1>
diff --git a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
--- a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
@@ -220,14 +220,12 @@
   // CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
   %f7 = constant 7.0: f32

-  // CHECK-DAG: %[[cond0:.*]] = constant 1 : i1
   // CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
   // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
   // CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
   // CHECK: affine.for %[[I:.*]] = 0 to 17 {
   // CHECK:   %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
-  // CHECK:   %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
-  // CHECK:   %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
+  // CHECK:   %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
   // CHECK:   scf.if %[[cond1]] {
   // CHECK:     %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref<?x?xf32>, vector<15xf32>
   // CHECK:     store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
@@ -253,7 +251,6 @@
 // CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
 func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
-  // CHECK: %[[cond0:.*]] = constant 1 : i1
   // CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
   // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
   // CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
@@ -261,8 +258,7 @@
   // CHECK: affine.for %[[I:.*]] = 0 to 17 {
   // CHECK:   %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
   // CHECK:   %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
-  // CHECK:   %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
-  // CHECK:   scf.if %[[cond1]] {
+  // CHECK:   scf.if %[[cmp]] {
   // CHECK:     %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
   // CHECK:     vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
   // CHECK:   }
@@ -271,3 +267,26 @@
       vector<17x15xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: transfer_write_progressive_not_masked(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
+func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
+  // CHECK-NOT:  scf.if
+  // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
+  // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
+  // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
+  // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 {
+  // CHECK-NEXT:   %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
+  // CHECK-NEXT:   %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
+  // CHECK-NEXT:   vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+  vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
+    vector<17x15xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -348,6 +348,16 @@

 // -----

+func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
+  %c3 = constant 3 : index
+  %f0 = constant 0.0 : f32
+  %vf0 = splat %f0 : vector<2x3xf32>
+  // expected-error@+1 {{ expects the optional masked attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {masked = [false], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
+}
+
+// -----
+
 func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -22,6 +22,8 @@
   %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : memref<?x?xf32>, vector<128xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+  %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>

   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -29,6 +31,8 @@
   vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+  vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>

   return
 }
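
Usage sketch (illustrative only; the function and value names below are hypothetical and this snippet is not part of the patch): the `masked` attribute introduced above is a boolean array with one entry per result of the permutation_map. Omitting it keeps the conservative default in which every dimension is masked; marking an entry `false` asserts that the access is in-bounds along that dimension, which is what allows the VectorToLLVM lowering to emit plain loads/stores and the VectorToSCF lowering to drop the scf.if bounds check for that dimension.

  func @masked_example(%src : memref<?x?xf32>, %dst : memref<?x?xf32>, %i : index, %j : index) {
    %pad = constant 0.0 : f32
    // Dimension 0 may run past the memref bound (masked); dimension 1 is asserted in-bounds.
    %v = vector.transfer_read %src[%i, %j], %pad {masked = [true, false]} :
      memref<?x?xf32>, vector<4x8xf32>
    // Both dimensions asserted in-bounds: this write may lower to simple stores.
    vector.transfer_write %v, %dst[%i, %j] {masked = [false, false]} :
      vector<4x8xf32>, memref<?x?xf32>
    return
  }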