diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -62,9 +62,12 @@ /// Collect a set of transfer read/write lowering patterns. /// /// These patterns lower transfer ops to simpler ops like `vector.load`, -/// `vector.store` and `vector.broadcast`. Includes all patterns of -/// populateVectorTransferPermutationMapLoweringPatterns. -void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns); +/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank +/// of at most `maxTransferRank` are lowered. This is useful when combined with +/// VectorToSCF, which reduces the rank of vector transfer ops. +void populateVectorTransferLoweringPatterns( + RewritePatternSet &patterns, + llvm::Optional maxTransferRank = llvm::None); /// Collect a set of transfer read/write lowering patterns that simplify the /// permutation map (e.g., converting it to a minor identity map) by inserting @@ -185,6 +188,10 @@ Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector); +/// Return true if the last dimension of the MemRefType has unit stride. Also +/// return true for memrefs with no strides. +bool isLastMemrefDimUnitStride(MemRefType type); + namespace impl { /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1409,9 +1409,9 @@ based on the element type of the memref. The shape of the result vector type determines the shape of the slice read from the start memory address.
The elements along each dimension of the slice are strided by the memref - strides. Only memref with default strides are allowed. These constraints - guarantee that elements read along the first dimension of the slice are - contiguous in memory. + strides. Only unit strides are allowed along the most minor memref + dimension. These constraints guarantee that elements read along the first + dimension of the slice are contiguous in memory. The memref element type can be a scalar or a vector type. If the memref element type is a scalar, it should match the element type of the result @@ -1470,6 +1470,8 @@ } }]; + let hasFolder = 1; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; } @@ -1484,9 +1486,9 @@ memref dimension based on the element type of the memref. The shape of the vector value to store determines the shape of the slice written from the start memory address. The elements along each dimension of the slice are - strided by the memref strides. Only memref with default strides are allowed. - These constraints guarantee that elements written along the first dimension - of the slice are contiguous in memory. + strided by the memref strides. Only unit strides are allowed along the most + minor memref dimension. These constraints guarantee that elements written + along the first dimension of the slice are contiguous in memory. The memref element type can be a scalar or a vector type. 
If the memref element type is a scalar, it should match the element type of the value @@ -1544,6 +1546,8 @@ } }]; + let hasFolder = 1; + let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict " "`:` type($base) `,` type($valueToStore)"; } @@ -1601,6 +1605,7 @@ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; + let hasFolder = 1; } def Vector_MaskedStoreOp : @@ -1653,6 +1658,7 @@ "$base `[` $indices `]` `,` $mask `,` $valueToStore " "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; + let hasFolder = 1; } def Vector_GatherOp : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -130,18 +130,6 @@ } return align; } -// Helper that returns data layout alignment of a memref associated with a -// transfer op, including additional information from assume_alignment calls -// on the source of the transfer -LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter, - VectorTransferOpInterface xfer, - unsigned &align) { - if (failed(getMemRefAlignment( - typeConverter, xfer.getShapedType().cast(), align))) - return failure(); - align = std::max(align, getAssumedAlignment(xfer.source())); - return success(); -} // Helper that returns data layout alignment of a memref associated with a // load, store, scatter, or gather op, including additional information from @@ -181,79 +169,6 @@ return rewriter.create(loc, pType, ptr); } -static LogicalResult -replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferReadOp xferOp, - ArrayRef operands, Value dataPtr) { - unsigned align; - if 
(failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - rewriter.replaceOpWithNewOp(xferOp, dataPtr, align); - return success(); -} - -static LogicalResult -replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferReadOp xferOp, ArrayRef operands, - Value dataPtr, Value mask) { - Type vecTy = typeConverter.convertType(xferOp.getVectorType()); - if (!vecTy) - return failure(); - - auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); - Value fill = rewriter.create(loc, vecTy, adaptor.padding()); - - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - rewriter.replaceOpWithNewOp( - xferOp, vecTy, dataPtr, mask, ValueRange{fill}, - rewriter.getI32IntegerAttr(align)); - return success(); -} - -static LogicalResult -replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferWriteOp xferOp, - ArrayRef operands, Value dataPtr) { - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); - rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr, - align); - return success(); -} - -static LogicalResult -replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferWriteOp xferOp, ArrayRef operands, - Value dataPtr, Value mask) { - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - - auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); - rewriter.replaceOpWithNewOp( - xferOp, adaptor.vector(), dataPtr, mask, - rewriter.getI32IntegerAttr(align)); - return success(); -} - -static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, - ArrayRef operands) 
{ - return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); -} - -static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, - ArrayRef operands) { - return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); -} - namespace { /// Conversion pattern for a vector.bitcast. @@ -1026,15 +941,6 @@ } }; -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -static bool isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - return succeeded(successStrides) && (strides.empty() || strides.back() == 1); -} - /// Returns the strides if the memory underlying `memRefType` has a contiguous /// static layout. static llvm::Optional> @@ -1145,83 +1051,6 @@ } }; -/// Conversion pattern that converts a 1-D vector transfer read/write op into a -/// a masked or unmasked read/write. -template -class VectorTransferConversion : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto adaptor = getTransferOpAdapter(xferOp, operands); - - if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty()) - return failure(); - if (xferOp.permutation_map() != - AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), - xferOp.getVectorType().getRank(), - xferOp->getContext())) - return failure(); - auto memRefType = xferOp.getShapedType().template dyn_cast(); - if (!memRefType) - return failure(); - // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.) - if (!isLastMemrefDimUnitStride(memRefType)) - return failure(); - // Out-of-bounds dims are handled by MaterializeTransferMask. 
- if (xferOp.hasOutOfBoundsDim()) - return failure(); - - auto toLLVMTy = [&](Type t) { - return this->getTypeConverter()->convertType(t); - }; - - Location loc = xferOp->getLoc(); - - if (auto memrefVectorElementType = - memRefType.getElementType().template dyn_cast()) { - // Memref has vector element type. - if (memrefVectorElementType.getElementType() != - xferOp.getVectorType().getElementType()) - return failure(); -#ifndef NDEBUG - // Check that memref vector type is a suffix of 'vectorType. - unsigned memrefVecEltRank = memrefVectorElementType.getRank(); - unsigned resultVecRank = xferOp.getVectorType().getRank(); - assert(memrefVecEltRank <= resultVecRank); - // TODO: Move this to isSuffix in Vector/Utils.h. - unsigned rankOffset = resultVecRank - memrefVecEltRank; - auto memrefVecEltShape = memrefVectorElementType.getShape(); - auto resultVecShape = xferOp.getVectorType().getShape(); - for (unsigned i = 0; i < memrefVecEltRank; ++i) - assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && - "memref vector element shape should match suffix of vector " - "result shape."); -#endif // ifndef NDEBUG - } - - // Get the source/dst address as an LLVM vector pointer. - VectorType vtp = xferOp.getVectorType(); - Value dataPtr = this->getStridedElementPtr( - loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); - Value vectorDataPtr = - castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); - - // Rewrite as an unmasked masked read / write. - if (!xferOp.mask()) - return replaceTransferOpWithLoadOrStore(rewriter, - *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr); - - // Rewrite as a masked read / write. 
- return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr, - xferOp.mask()); - } -}; - class VectorPrintOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1450,9 +1279,10 @@ VectorLoadStoreConversion, VectorGatherOpConversion, VectorScatterOpConversion, - VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, - VectorTransferConversion, - VectorTransferConversion>(converter); + VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( + converter); + // Transfer ops with rank > 1 are handled by VectorToSCF. + populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -64,6 +64,8 @@ populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorContractLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns); + // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
+ populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -71,6 +73,7 @@ LLVMTypeConverter converter(&getContext()); RewritePatternSet patterns(&getContext()); populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations); + populateVectorTransferLoweringPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns, reassociateFPReductions); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -89,7 +89,7 @@ .add( vectorTransformsOptions, context); - vector::populateVectorTransferLoweringPatterns( + vector::populateVectorTransferPermutationMapLoweringPatterns( vectorContractLoweringPatterns); (void)applyPatternsAndFoldGreedily( func, std::move(vectorContractLoweringPatterns)); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -102,6 +102,15 @@ return false; } +/// Return true if the last dimension of the MemRefType has unit stride. Also +/// return true for memrefs with no strides. 
+bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) { + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + return succeeded(successStrides) && (strides.empty() || strides.back() == 1); +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -2953,9 +2962,8 @@ static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, MemRefType memRefTy) { - auto affineMaps = memRefTy.getAffineMaps(); - if (!affineMaps.empty()) - return op->emitOpError("base memref should have a default identity layout"); + if (!isLastMemrefDimUnitStride(memRefTy)) + return op->emitOpError("most minor memref dim must have unit stride"); return success(); } @@ -2981,6 +2989,12 @@ return success(); } +OpFoldResult LoadOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// @@ -3008,6 +3022,11 @@ return success(); } +LogicalResult StoreOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // MaskedLoadOp //===----------------------------------------------------------------------===// @@ -3056,6 +3075,12 @@ results.add(context); } +OpFoldResult MaskedLoadOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // MaskedStoreOp //===----------------------------------------------------------------------===// @@ -3101,6 +3126,11 @@ results.add(context); } +LogicalResult MaskedStoreOp::fold(ArrayRef operands, + 
SmallVectorImpl &results) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // GatherOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2464,26 +2464,34 @@ /// Progressive lowering of transfer_read. This pattern supports lowering of /// `vector.transfer_read` to a combination of `vector.load` and /// `vector.broadcast` if all of the following hold: -/// - The op reads from a memref with the default layout. +/// - Stride of most minor memref dimension must be 1. /// - Out-of-bounds masking is not required. /// - If the memref's element type is a vector type then it coincides with the /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). -/// - The op has no mask. struct TransferReadToVectorLoadLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + TransferReadToVectorLoadLowering(MLIRContext *context, + llvm::Optional maxRank) + : OpRewritePattern(context), + maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) + return failure(); SmallVector broadcastedDims; - // TODO: Support permutations. + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. if (!read.permutation_map().isMinorIdentityWithBroadcasting( &broadcastedDims)) return failure(); auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. 
+ if (!vector::isLastMemrefDimUnitStride(memRefType)) + return failure(); // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. @@ -2497,32 +2505,44 @@ // `vector.load` supports vector types as memref's elements only when the // resulting vector type is the same as the element type. - if (memRefType.getElementType().isa() && - memRefType.getElementType() != unbroadcastedVectorType) + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) return failure(); - // Only the default layout is supported by `vector.load`. - // TODO: Support non-default layouts. - if (!memRefType.getAffineMaps().empty()) + // Otherwise, element types of the memref and the vector must match. + if (!memrefElTy.isa() && + memrefElTy != read.getVectorType().getElementType()) return failure(); - // TODO: When out-of-bounds masking is required, we can create a - // MaskedLoadOp. + + // Out-of-bounds dims are handled by MaterializeTransferMask. if (read.hasOutOfBoundsDim()) return failure(); - if (read.mask()) - return failure(); - auto loadOp = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.source(), read.indices()); + // Create vector load op. + Operation *loadOp; + if (read.mask()) { + Value fill = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.padding()); + loadOp = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), + read.mask(), fill); + } else { + loadOp = rewriter.create(read.getLoc(), + unbroadcastedVectorType, + read.source(), read.indices()); + } + // Insert a broadcasting op if required. 
if (!broadcastedDims.empty()) { rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp.result()); + read, read.getVectorType(), loadOp->getResult(0)); } else { - rewriter.replaceOp(read, loadOp.result()); + rewriter.replaceOp(read, loadOp->getResult(0)); } return success(); } + + llvm::Optional maxTransferRank; }; /// Replace a scalar vector.load with a memref.load. @@ -2545,44 +2565,56 @@ /// Progressive lowering of transfer_write. This pattern supports lowering of /// `vector.transfer_write` to `vector.store` if all of the following hold: -/// - The op writes to a memref with the default layout. +/// - Stride of most minor memref dimension must be 1. /// - Out-of-bounds masking is not required. /// - If the memref's element type is a vector type then it coincides with the /// type of the written value. /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). -/// - The op has no mask. struct TransferWriteToVectorStoreLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + TransferWriteToVectorStoreLowering(MLIRContext *context, + llvm::Optional maxRank) + : OpRewritePattern(context), + maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { - // TODO: Support non-minor-identity maps + if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) + return failure(); + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. if (!write.permutation_map().isMinorIdentity()) return failure(); auto memRefType = write.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. + if (!vector::isLastMemrefDimUnitStride(memRefType)) + return failure(); // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. 
- if (memRefType.getElementType().isa() && - memRefType.getElementType() != write.getVectorType()) + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != write.getVectorType()) return failure(); - // Only the default layout is supported by `vector.store`. - // TODO: Support non-default layouts. - if (!memRefType.getAffineMaps().empty()) + // Otherwise, element types of the memref and the vector must match. + if (!memrefElTy.isa() && + memrefElTy != write.getVectorType().getElementType()) return failure(); - // TODO: When out-of-bounds masking is required, we can create a - // MaskedStoreOp. + // Out-of-bounds dims are handled by MaterializeTransferMask. if (write.hasOutOfBoundsDim()) return failure(); - if (write.mask()) - return failure(); - rewriter.replaceOpWithNewOp( - write, write.vector(), write.source(), write.indices()); + if (write.mask()) { + rewriter.replaceOpWithNewOp( + write, write.source(), write.indices(), write.mask(), write.vector()); + } else { + rewriter.replaceOpWithNewOp( + write, write.vector(), write.source(), write.indices()); + } return success(); } + + llvm::Optional maxTransferRank; }; /// Transpose a vector transfer op's `in_bounds` attribute according to given @@ -2624,6 +2656,8 @@ PatternRewriter &rewriter) const override { SmallVector permutation; AffineMap map = op.permutation_map(); + if (map.getNumResults() == 0) + return failure(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) return failure(); AffineMap permutationMap = @@ -3680,11 +3714,11 @@ } void mlir::vector::populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns) { - patterns - .add(patterns.getContext()); - populateVectorTransferPermutationMapLoweringPatterns(patterns); + RewritePatternSet &patterns, llvm::Optional maxTransferRank) { + patterns.add(patterns.getContext(), + maxTransferRank); + patterns.add(patterns.getContext()); } void mlir::vector::populateVectorMultiReductionLoweringPatterns( diff 
--git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1212,18 +1212,19 @@ // CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32> // CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32> // -// 4. Bitcast to vector form. +// 4. Create pass-through vector. +// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32> +// +// 5. Bitcast to vector form. // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : // CHECK-SAME: !llvm.ptr to !llvm.ptr> // -// 5. Rewrite as a masked read. -// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32> +// 6. Rewrite as a masked read. // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], // CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} : // CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xf32>) -> vector<17xf32> - // // 1. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 
// CHECK: %[[linearIndex_b:.*]] = constant dense @@ -1264,8 +1265,9 @@ } // CHECK-LABEL: func @transfer_read_index_1d // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex> -// CHECK: %[[C7:.*]] = constant 7 -// CHECK: %{{.*}} = unrealized_conversion_cast %[[C7]] : index to i64 +// CHECK: %[[C7:.*]] = constant 7 : index +// CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex> +// CHECK: %{{.*}} = unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : // CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xi64>) -> vector<17xi64> @@ -1384,26 +1386,6 @@ // ----- -func @transfer_read_1d_cast(%A : memref, %base: index) -> vector<12xi8> { - %c0 = constant 0: i32 - %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} : - memref, vector<12xi8> - return %v: vector<12xi8> -} -// CHECK-LABEL: func @transfer_read_1d_cast -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<12xi8> -// -// 1. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : -// CHECK-SAME: !llvm.ptr to !llvm.ptr> -// -// 2. Rewrite as a load. 
-// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr> - -// ----- - func @genbool_1d() -> vector<8xi1> { %0 = vector.constant_mask [4] : vector<8xi1> return %0 : vector<8xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1094,11 +1094,12 @@ // ----- -func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, +func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>, %i : index, %j : index, %value : vector<8xf32>) { - // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}} - vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, + // expected-error@+1 {{'vector.store' op most minor memref dim must have unit stride}} + vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>, vector<8xf32> + return } // ----- diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -114,14 +114,11 @@ // ----- -// TODO: transfer_read/write cannot be lowered to vector.load/store because the -// memref has a non-default layout. 
// CHECK-LABEL: func @transfer_nondefault_layout( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true]} : memref<8x8xf32, #{{.*}}>, vector<4xf32> -// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #{{.*}}> +// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> +// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> // CHECK-NEXT: return %[[RES]] : vector<4xf32> // CHECK-NEXT: } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -436,6 +436,7 @@ void runOnFunction() override { RewritePatternSet patterns(&getContext()); populateVectorTransferLoweringPatterns(patterns); + populateVectorTransferPermutationMapLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };