diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -68,7 +68,7 @@ /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions = false, bool enableIndexOptimizations = true); + bool reassociateFPReductions = false); /// Create a pass to convert vector operations to the LLVMIR dialect. std::unique_ptr> createConvertVectorToLLVMPass( diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -88,6 +88,10 @@ /// `vector.store` and `vector.broadcast`. void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns); +/// These patterns materialize masks for various vector ops such as transfers. +void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, + bool enableIndexOptimizations); + /// An attribute that specifies the combining function for `vector.contract`, and `vector.reduction`. class CombiningKindAttr diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1135,10 +1135,12 @@ Vector_Op<"transfer_read", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + AttrSizedOperandSegments ]>, Arguments<(ins AnyShaped:$source, Variadic:$indices, AffineMapAttr:$permutation_map, AnyType:$padding, + Optional>:$mask, OptionalAttr:$in_bounds)>, Results<(outs AnyVector:$vector)> { @@ -1167,13 +1169,19 @@ return type. An SSA value `padding` of the same elemental type as the MemRef/Tensor is - provided to specify a fallback value in the case of out-of-bounds accesses. + provided to specify a fallback value in the case of out-of-bounds accesses + and/or masking. + + An optional SSA value `mask` of the same shape as the vector type may be + specified to mask out elements. Elements whose corresponding mask element + is `0` are masked out and replaced with `padding`. An optional boolean array attribute is provided to specify which dimensions of the transfer are guaranteed to be within bounds. The absence of this `in_bounds` attribute signifies that any dimension of the transfer may be out-of-bounds. A `vector.transfer_read` can be lowered to a simple load if - all dimensions are specified to be within bounds. + all dimensions are specified to be within bounds and no `mask` was + specified. This operation is called 'read' by opposition to 'load' because the super-vector granularity is generally not representable with a single @@ -1299,6 +1307,14 @@ // 'getMinorIdentityMap' (resp. zero). OpBuilder<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$inBounds)>, + // Builder that does not set mask. + OpBuilder<(ins "Type":$vector, "Value":$source, + "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding, + "ArrayAttr":$inBounds)>, + // Builder that does not set mask.
+ OpBuilder<(ins "Type":$vector, "Value":$source, + "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding, + "ArrayAttr":$inBounds)> ]; let hasFolder = 1; @@ -1308,11 +1324,13 @@ Vector_Op<"transfer_write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + AttrSizedOperandSegments ]>, Arguments<(ins AnyVector:$vector, AnyShaped:$source, Variadic:$indices, AffineMapAttr:$permutation_map, + Optional>:$mask, OptionalAttr:$in_bounds)>, Results<(outs Optional:$result)> { @@ -1341,11 +1359,16 @@ The size of the slice is specified by the size of the vector. + An optional SSA value `mask` of the same shape as the vector type may be + specified to mask out elements. Elements whose corresponding mask element + is `0` are masked out. + An optional boolean array attribute is provided to specify which dimensions of the transfer are guaranteed to be within bounds. The absence of this `in_bounds` attribute signifies that any dimension of the transfer may be out-of-bounds. A `vector.transfer_write` can be lowered to a simple store - if all dimensions are specified to be within bounds. + if all dimensions are specified to be within bounds and no `mask` was + specified. This operation is called 'write' by opposition to 'store' because the super-vector granularity is generally not representable with a single @@ -1391,6 +1414,8 @@ "AffineMap":$permutationMap)>, OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>, + OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, + "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>, OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>, ]; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -104,66 +104,6 @@ return res; } -static Value createCastToIndexLike(ConversionPatternRewriter &rewriter, - Location loc, Type targetType, Value value) { - if (targetType == value.getType()) - return value; - - bool targetIsIndex = targetType.isIndex(); - bool valueIsIndex = value.getType().isIndex(); - if (targetIsIndex ^ valueIsIndex) - return rewriter.create(loc, targetType, value); - - auto targetIntegerType = targetType.dyn_cast(); - auto valueIntegerType = value.getType().dyn_cast(); - assert(targetIntegerType && valueIntegerType && - "unexpected cast between types other than integers and index"); - assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); - - if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return rewriter.create(loc, targetIntegerType, value); - return rewriter.create(loc, targetIntegerType, value); -} - -// Helper that returns a vector comparison that constructs a mask: -// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] -// -// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, -// much more compact, IR for this operation, but LLVM eventually -// generates more elaborate instructions for this intrinsic since it -// is very conservative on the boundary conditions. 
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter, - Operation *op, bool enableIndexOptimizations, - int64_t dim, Value b, Value *off = nullptr) { - auto loc = op->getLoc(); - // If we can assume all indices fit in 32-bit, we perform the vector - // comparison in 32-bit to get a higher degree of SIMD parallelism. - // Otherwise we perform the vector comparison using 64-bit indices. - Value indices; - Type idxType; - if (enableIndexOptimizations) { - indices = rewriter.create( - loc, rewriter.getI32VectorAttr( - llvm::to_vector<4>(llvm::seq(0, dim)))); - idxType = rewriter.getI32Type(); - } else { - indices = rewriter.create( - loc, rewriter.getI64VectorAttr( - llvm::to_vector<4>(llvm::seq(0, dim)))); - idxType = rewriter.getI64Type(); - } - // Add in an offset if requested. - if (off) { - Value o = createCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = rewriter.create(loc, indices.getType(), o); - indices = rewriter.create(loc, ov, indices); - } - // Construct the vector comparison. - Value bound = createCastToIndexLike(rewriter, loc, idxType, b); - Value bounds = rewriter.create(loc, indices.getType(), bound); - return rewriter.create(loc, CmpIPredicate::slt, indices, bounds); -} - // Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { @@ -250,7 +190,7 @@ if (failed(getMemRefAlignment( typeConverter, xferOp.getShapedType().cast(), align))) return failure(); - auto adaptor = TransferWriteOpAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr, align); return success(); @@ -266,7 +206,7 @@ typeConverter, xferOp.getShapedType().cast(), align))) return failure(); - auto adaptor = TransferWriteOpAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); rewriter.replaceOpWithNewOp( xferOp, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); @@ -275,12 +215,12 @@ static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, ArrayRef operands) { - return TransferReadOpAdaptor(operands); + return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); } static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef operands) { - return TransferWriteOpAdaptor(operands); + return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); } namespace { @@ -618,33 +558,6 @@ const bool reassociateFPReductions; }; -/// Conversion pattern for a vector.create_mask (1-D only). 
-class VectorCreateMaskOpConversion - : public ConvertOpToLLVMPattern { -public: - explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, - bool enableIndexOpt) - : ConvertOpToLLVMPattern(typeConv), - enableIndexOptimizations(enableIndexOpt) {} - - LogicalResult - matchAndRewrite(vector::CreateMaskOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto dstType = op.getType(); - int64_t rank = dstType.getRank(); - if (rank == 1) { - rewriter.replaceOp( - op, buildVectorComparison(rewriter, op, enableIndexOptimizations, - dstType.getDimSize(0), operands[0])); - return success(); - } - return failure(); - } - -private: - const bool enableIndexOptimizations; -}; - class VectorShuffleOpConversion : public ConvertOpToLLVMPattern { public: @@ -1177,20 +1090,12 @@ } }; -/// Conversion pattern that converts a 1-D vector transfer read/write op in a -/// sequence of: -/// 1. Get the source/dst address as an LLVM vector pointer. -/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. -/// 4. Create a mask where offsetVector is compared against memref upper bound. -/// 5. Rewrite op as a masked read or write. +/// Conversion pattern that converts a 1-D vector transfer read/write op into +/// a masked or unmasked read/write. template class VectorTransferConversion : public ConvertOpToLLVMPattern { public: - explicit VectorTransferConversion(LLVMTypeConverter &typeConv, - bool enableIndexOpt) - : ConvertOpToLLVMPattern(typeConv), - enableIndexOptimizations(enableIndexOpt) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, @@ -1212,6 +1117,9 @@ auto strides = computeContiguousStrides(memRefType); if (!strides) return failure(); + // Out-of-bounds dims are handled by MaterializeTransferMask. + if (xferOp.hasOutOfBoundsDim()) + return failure(); auto toLLVMTy = [&](Type t) { return this->getTypeConverter()->convertType(t); @@ -1241,40 +1149,24 @@ #endif // ifndef NDEBUG } - // 1. Get the source/dst address as an LLVM vector pointer. + // Get the source/dst address as an LLVM vector pointer. VectorType vtp = xferOp.getVectorType(); Value dataPtr = this->getStridedElementPtr( loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); Value vectorDataPtr = castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); - if (xferOp.isDimInBounds(0)) + // Rewrite as an unmasked read / write. + if (!xferOp.mask()) return replaceTransferOpWithLoadOrStore(rewriter, *this->getTypeConverter(), loc, xferOp, operands, vectorDataPtr); - // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. - // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. - // 4. Let dim the memref dimension, compute the vector comparison mask - // (in-bounds mask): - // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] - // - // TODO: when the leaf transfer rank is k > 1, we need the last `k` - // dimensions here. - unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue(); - unsigned lastIndex = llvm::size(xferOp.indices()) - 1; - Value off = xferOp.indices()[lastIndex]; - Value dim = rewriter.create(loc, xferOp.source(), lastIndex); - Value mask = buildVectorComparison( - rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); - - // 5. Rewrite as a masked read / write. + // Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr, mask); + xferOp, operands, vectorDataPtr, + xferOp.mask()); } - -private: - const bool enableIndexOptimizations; }; class VectorPrintOpConversion : public ConvertOpToLLVMPattern { @@ -1484,17 +1376,13 @@ /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool reassociateFPReductions, bool enableIndexOptimizations) { + bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.add(ctx); patterns.add(converter, reassociateFPReductions); - patterns.add, - VectorTransferConversion>( - converter, enableIndexOptimizations); patterns .add, VectorGatherOpConversion, VectorScatterOpConversion, - VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( - converter); + VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, + VectorTransferConversion, + VectorTransferConversion>(converter); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -71,9 +71,10 @@ // Convert to the LLVM IR dialect. LLVMTypeConverter converter(&getContext()); RewritePatternSet patterns(&getContext()); + populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); - populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, enableIndexOptimizations); + populateVectorToLLVMConversionPatterns(converter, patterns, + reassociateFPReductions); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. 
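The ConvertVectorToLLVMPass hunk above captures the new division of labor: mask materialization now runs as ordinary rewrite patterns ahead of the LLVM conversion patterns, and the `enableIndexOptimizations` flag moves with it, keeping the LLVM lowering a mechanical, bounds-agnostic rewrite. A minimal sketch of how a lowering pass outside this patch might adapt; the pass name is hypothetical, includes and namespace qualifiers are elided, and only the populate* entry points changed here are assumed:

void MyLowerVectorToLLVMPass::runOnOperation() {  // hypothetical downstream pass
  MLIRContext *ctx = &getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  // Out-of-bounds transfer dimensions become explicit `mask` operands before
  // any LLVM lowering runs.
  populateVectorMaskMaterializationPatterns(patterns,
                                            /*enableIndexOptimizations=*/true);
  // The index-optimization flag is no longer threaded through this call.
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         /*reassociateFPReductions=*/false);
  LLVMConversionTarget target(*ctx);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}
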
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -42,7 +42,7 @@ LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { - auto adaptor = TransferWriteOpAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dwordConfig, vindex, offsetSizeInBytes, glc, slc); @@ -62,7 +62,7 @@ LogicalResult matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - typename ConcreteOp::Adaptor adaptor(operands); + typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary()); if (xferOp.getVectorType().getRank() > 1 || llvm::size(xferOp.indices()) == 0) diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -538,6 +538,8 @@ using namespace mlir::edsc::op; TransferReadOp transfer = cast(op); + if (transfer.mask()) + return failure(); auto memRefType = transfer.getShapedType().dyn_cast(); if (!memRefType) return failure(); @@ -624,6 +626,8 @@ using namespace edsc::op; TransferWriteOp transfer = cast(op); + if (transfer.mask()) + return failure(); auto memRefType = transfer.getShapedType().template dyn_cast(); if (!memRefType) return failure(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2295,8 +2295,27 @@ build(builder, result, vectorType, source, indices, permMap, inBounds); } +/// Builder that does not provide a mask. +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + Type vectorType, Value source, ValueRange indices, + AffineMap permutationMap, Value padding, + ArrayAttr inBounds) { + build(builder, result, vectorType, source, indices, permutationMap, padding, + /*mask=*/Value(), inBounds); +} + +/// Builder that does not provide a mask. 
+void TransferReadOp::build(OpBuilder &builder, OperationState &result, + Type vectorType, Value source, ValueRange indices, + AffineMapAttr permutationMap, Value padding, + ArrayAttr inBounds) { + build(builder, result, vectorType, source, indices, permutationMap, padding, + /*mask=*/Value(), inBounds); +} + static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { - SmallVector elidedAttrs; + SmallVector elidedAttrs; + elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); if (op.permutation_map().isMinorIdentity()) elidedAttrs.push_back(op.getPermutationMapAttrName()); bool elideInBounds = true; @@ -2316,27 +2335,36 @@ static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.source() << "[" << op.indices() << "], " << op.padding(); + if (op.mask()) + p << ", " << op.mask(); printTransferAttrs(p, cast(op.getOperation())); p << " : " << op.getShapedType() << ", " << op.getVectorType(); } static ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { + auto &builder = parser.getBuilder(); llvm::SMLoc typesLoc; OpAsmParser::OperandType sourceInfo; SmallVector indexInfo; OpAsmParser::OperandType paddingInfo; SmallVector types; + OpAsmParser::OperandType maskInfo; // Parsing with support for paddingValue. if (parser.parseOperand(sourceInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseComma() || parser.parseOperand(paddingInfo) || - parser.parseOptionalAttrDict(result.attributes) || + parser.parseComma() || parser.parseOperand(paddingInfo)) + return failure(); + ParseResult hasMask = parser.parseOptionalComma(); + if (hasMask.succeeded()) { + parser.parseOperand(maskInfo); + } + if (parser.parseOptionalAttrDict(result.attributes) || parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); - auto indexType = parser.getBuilder().getIndexType(); + auto indexType = builder.getIndexType(); auto shapedType = types[0].dyn_cast(); if (!shapedType || !shapedType.isa()) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); @@ -2349,12 +2377,21 @@ auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } - return failure( - parser.resolveOperand(sourceInfo, shapedType, result.operands) || + if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || parser.resolveOperand(paddingInfo, shapedType.getElementType(), - result.operands) || - parser.addTypeToList(vectorType, result.types)); + result.operands)) + return failure(); + if (hasMask.succeeded()) { + auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); + if (parser.resolveOperand(maskInfo, maskType, result.operands)) + return failure(); + } + result.addAttribute( + TransferReadOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(indexInfo.size()), 1, + static_cast(hasMask.succeeded())})); + return parser.addTypeToList(vectorType, result.types); } static LogicalResult verify(TransferReadOp op) { @@ -2525,7 +2562,7 @@ /*optional*/ ArrayAttr inBounds) { Type resultType = source.getType().dyn_cast(); build(builder, result, resultType, vector, source, indices, permutationMap, - inBounds); + /*mask=*/Value(), inBounds); } void TransferWriteOp::build(OpBuilder &builder, 
OperationState &result, @@ -2534,24 +2571,39 @@ /*optional*/ ArrayAttr inBounds) { Type resultType = source.getType().dyn_cast(); build(builder, result, resultType, vector, source, indices, permutationMap, - inBounds); + /*mask=*/Value(), inBounds); +} + +void TransferWriteOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value source, ValueRange indices, + AffineMap permutationMap, /*optional*/ Value mask, + /*optional*/ ArrayAttr inBounds) { + Type resultType = source.getType().dyn_cast(); + build(builder, result, resultType, vector, source, indices, permutationMap, + mask, inBounds); } static ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { + auto &builder = parser.getBuilder(); llvm::SMLoc typesLoc; OpAsmParser::OperandType vectorInfo, sourceInfo; SmallVector indexInfo; SmallVector types; + OpAsmParser::OperandType maskInfo; if (parser.parseOperand(vectorInfo) || parser.parseComma() || parser.parseOperand(sourceInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) + return failure(); + ParseResult hasMask = parser.parseOptionalComma(); + if (hasMask.succeeded() && parser.parseOperand(maskInfo)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes) || parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); - auto indexType = parser.getBuilder().getIndexType(); + auto indexType = builder.getIndexType(); VectorType vectorType = types[0].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); @@ -2564,17 +2616,28 @@ auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } - return failure( - parser.resolveOperand(vectorInfo, vectorType, result.operands) || + if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || parser.resolveOperand(sourceInfo, shapedType, result.operands) || - parser.resolveOperands(indexInfo, indexType, result.operands) || - (shapedType.isa() && - parser.addTypeToList(shapedType, result.types))); + parser.resolveOperands(indexInfo, indexType, result.operands)) + return failure(); + if (hasMask.succeeded()) { + auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); + if (parser.resolveOperand(maskInfo, maskType, result.operands)) + return failure(); + } + result.addAttribute( + TransferWriteOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, 1, static_cast(indexInfo.size()), + static_cast(hasMask.succeeded())})); + return failure(shapedType.isa() && + parser.addTypeToList(shapedType, result.types)); } static void print(OpAsmPrinter &p, TransferWriteOp op) { p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]"; + if (op.mask()) + p << ", " << op.mask(); printTransferAttrs(p, cast(op.getOperation())); p << " : " << op.getVectorType() << ", " << op.getShapedType(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -596,6 +596,8 @@ OpBuilder &builder) { if (!isIdentitySuffix(readOp.permutation_map())) return nullptr; + if (readOp.mask()) + return nullptr; auto 
sourceVectorType = readOp.getVectorType(); SmallVector strides(targetShape.size(), 1); @@ -641,6 +643,8 @@ auto writeOp = cast(op); if (!isIdentitySuffix(writeOp.permutation_map())) return failure(); + if (writeOp.mask()) + return failure(); VectorType sourceVectorType = writeOp.getVectorType(); SmallVector strides(targetShape.size(), 1); TupleType tupleType = generateExtractSlicesOpResultType( @@ -722,6 +726,9 @@ if (ignoreFilter && ignoreFilter(readOp)) return failure(); + if (readOp.mask()) + return failure(); + // TODO: Support splitting TransferReadOp with non-identity permutation // maps. Repurpose code from MaterializeVectors transformation. if (!isIdentitySuffix(readOp.permutation_map())) @@ -768,6 +775,9 @@ if (ignoreFilter && ignoreFilter(writeOp)) return failure(); + if (writeOp.mask()) + return failure(); + // TODO: Support splitting TransferWriteOp with non-identity permutation // maps. Repurpose code from MaterializeVectors transformation. if (!isIdentitySuffix(writeOp.permutation_map())) @@ -2546,6 +2556,9 @@ "Expected splitFullAndPartialTransferPrecondition to hold"); auto xferReadOp = dyn_cast(xferOp.getOperation()); + if (xferReadOp.mask()) + return failure(); + // TODO: add support for write case. if (!xferReadOp) return failure(); @@ -2677,6 +2690,8 @@ dyn_cast(*read.getResult().getUsers().begin()); if (!extract) return failure(); + if (read.mask()) + return failure(); edsc::ScopedContext scope(rewriter, read.getLoc()); using mlir::edsc::op::operator+; using mlir::edsc::op::operator*; @@ -2712,6 +2727,8 @@ auto insert = write.vector().getDefiningOp(); if (!insert) return failure(); + if (write.mask()) + return failure(); edsc::ScopedContext scope(rewriter, write.getLoc()); using mlir::edsc::op::operator+; using mlir::edsc::op::operator*; @@ -2742,6 +2759,7 @@ /// - If the memref's element type is a vector type then it coincides with the /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). +/// - The op has no mask. struct TransferReadToVectorLoadLowering : public OpRewritePattern { TransferReadToVectorLoadLowering(MLIRContext *context) @@ -2780,7 +2798,8 @@ // MaskedLoadOp. if (read.hasOutOfBoundsDim()) return failure(); - + if (read.mask()) + return failure(); Operation *loadOp; if (!broadcastedDims.empty() && unbroadcastedVectorType.getNumElements() == 1) { @@ -2815,6 +2834,7 @@ /// type of the written value. /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). +/// - The op has no mask. struct TransferWriteToVectorStoreLowering : public OpRewritePattern { TransferWriteToVectorStoreLowering(MLIRContext *context) @@ -2840,6 +2860,8 @@ // MaskedStoreOp. if (write.hasOutOfBoundsDim()) return failure(); + if (write.mask()) + return failure(); rewriter.replaceOpWithNewOp( write, write.vector(), write.source(), write.indices()); return success(); @@ -2880,6 +2902,8 @@ map.getPermutationMap(permutation, op.getContext()); if (permutationMap.isIdentity()) return failure(); + if (op.mask()) + return failure(); // Caluclate the map of the new read by applying the inverse permutation. 
permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); @@ -2914,6 +2938,8 @@ LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { + if (op.mask()) + return failure(); AffineMap map = op.permutation_map(); unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { @@ -3062,6 +3088,9 @@ LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + if (read.mask()) + return failure(); + auto shapedType = read.source().getType().cast(); if (shapedType.getElementType() != read.getVectorType().getElementType()) return failure(); @@ -3102,6 +3131,9 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { + if (write.mask()) + return failure(); + auto shapedType = write.source().getType().dyn_cast(); if (shapedType.getElementType() != write.getVectorType().getElementType()) return failure(); @@ -3371,6 +3403,151 @@ } }; +static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, + Type targetType, Value value) { + if (targetType == value.getType()) + return value; + + bool targetIsIndex = targetType.isIndex(); + bool valueIsIndex = value.getType().isIndex(); + if (targetIsIndex ^ valueIsIndex) + return rewriter.create(loc, targetType, value); + + auto targetIntegerType = targetType.dyn_cast(); + auto valueIntegerType = value.getType().dyn_cast(); + assert(targetIntegerType && valueIntegerType && + "unexpected cast between types other than integers and index"); + assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); + + if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) + return rewriter.create(loc, targetIntegerType, value); + return rewriter.create(loc, targetIntegerType, value); +} + +// Helper that returns a vector comparison that constructs a mask: +// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] +// +// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, +// much more compact, IR for this operation, but LLVM eventually +// generates more elaborate instructions for this intrinsic since it +// is very conservative on the boundary conditions. +static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, + bool enableIndexOptimizations, int64_t dim, + Value b, Value *off = nullptr) { + auto loc = op->getLoc(); + // If we can assume all indices fit in 32-bit, we perform the vector + // comparison in 32-bit to get a higher degree of SIMD parallelism. + // Otherwise we perform the vector comparison using 64-bit indices. + Value indices; + Type idxType; + if (enableIndexOptimizations) { + indices = rewriter.create( + loc, rewriter.getI32VectorAttr( + llvm::to_vector<4>(llvm::seq(0, dim)))); + idxType = rewriter.getI32Type(); + } else { + indices = rewriter.create( + loc, rewriter.getI64VectorAttr( + llvm::to_vector<4>(llvm::seq(0, dim)))); + idxType = rewriter.getI64Type(); + } + // Add in an offset if requested. + if (off) { + Value o = createCastToIndexLike(rewriter, loc, idxType, *off); + Value ov = rewriter.create(loc, indices.getType(), o); + indices = rewriter.create(loc, ov, indices); + } + // Construct the vector comparison. 
+ Value bound = createCastToIndexLike(rewriter, loc, idxType, b); + Value bounds = rewriter.create(loc, indices.getType(), bound); + return rewriter.create(loc, CmpIPredicate::slt, indices, bounds); +} + +template +struct MaterializeTransferMask : public OpRewritePattern { +public: + explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) + : mlir::OpRewritePattern(context), + enableIndexOptimizations(enableIndexOpt) {} + + LogicalResult matchAndRewrite(ConcreteOp xferOp, + PatternRewriter &rewriter) const override { + if (!xferOp.hasOutOfBoundsDim()) + return failure(); + + if (xferOp.getVectorType().getRank() > 1 || + llvm::size(xferOp.indices()) == 0) + return failure(); + + Location loc = xferOp->getLoc(); + VectorType vtp = xferOp.getVectorType(); + + // * Create a vector with linear indices [ 0 .. vector_length - 1 ]. + // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. + // * Let dim the memref dimension, compute the vector comparison mask + // (in-bounds mask): + // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] + // + // TODO: when the leaf transfer rank is k > 1, we need the last `k` + // dimensions here. + unsigned vecWidth = vtp.getNumElements(); + unsigned lastIndex = llvm::size(xferOp.indices()) - 1; + Value off = xferOp.indices()[lastIndex]; + Value dim = rewriter.create(loc, xferOp.source(), lastIndex); + Value mask = buildVectorComparison( + rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); + + if (xferOp.mask()) { + // Intersect the in-bounds with the mask specified as an op parameter. + mask = rewriter.create(loc, mask, xferOp.mask()); + } + + rewriter.updateRootInPlace(xferOp, [&]() { + xferOp.maskMutable().assign(mask); + xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); + }); + + return success(); + } + +private: + const bool enableIndexOptimizations; +}; + +/// Conversion pattern for a vector.create_mask (1-D only). +class VectorCreateMaskOpConversion + : public OpRewritePattern { +public: + explicit VectorCreateMaskOpConversion(MLIRContext *context, + bool enableIndexOpt) + : mlir::OpRewritePattern(context), + enableIndexOptimizations(enableIndexOpt) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + int64_t rank = dstType.getRank(); + if (rank == 1) { + rewriter.replaceOp( + op, buildVectorComparison(rewriter, op, enableIndexOptimizations, + dstType.getDimSize(0), op.getOperand(0))); + return success(); + } + return failure(); + } + +private: + const bool enableIndexOptimizations; +}; + +void mlir::vector::populateVectorMaskMaterializationPatterns( + RewritePatternSet &patterns, bool enableIndexOptimizations) { + patterns.add, + MaterializeTransferMask>( + patterns.getContext(), enableIndexOptimizations); +} + // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO: Add this as DRR pattern. 
void mlir::vector::populateVectorToVectorTransformationPatterns( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir @@ -3,20 +3,19 @@ // CMP32-LABEL: @genbool_var_1d( // CMP32-SAME: %[[ARG:.*]]: index) -// CMP32: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64 // CMP32: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32> -// CMP32: %[[T1:.*]] = trunci %[[A]] : i64 to i32 +// CMP32: %[[T1:.*]] = index_cast %[[ARG]] : index to i32 // CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32> // CMP32: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi32> // CMP32: return %[[T3]] : vector<11xi1> // CMP64-LABEL: @genbool_var_1d( // CMP64-SAME: %[[ARG:.*]]: index) -// CMP64: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64 // CMP64: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64> -// CMP64: %[[T1:.*]] = splat %[[A]] : vector<11xi64> -// CMP64: %[[T2:.*]] = cmpi slt, %[[T0]], %[[T1]] : vector<11xi64> -// CMP64: return %[[T2]] : vector<11xi1> +// CMP64: %[[T1:.*]] = index_cast %[[ARG]] : index to i64 +// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64> +// CMP64: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi64> +// CMP64: return %[[T3]] : vector<11xi1> func @genbool_var_1d(%arg0: index) -> vector<11xi1> { %0 = vector.create_mask %arg0 : vector<11xi1> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1049,31 +1049,31 @@ // CHECK-LABEL: func @transfer_read_1d // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32> // CHECK: %[[c7:.*]] = constant 7.0 -// -// 1. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : -// CHECK-SAME: !llvm.ptr to !llvm.ptr> // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref // -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. +// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // CHECK: %[[linearIndex:.*]] = constant dense // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : // CHECK-SAME: vector<17xi32> // -// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. +// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. // CHECK: %[[otrunc:.*]] = index_cast %[[BASE]] : index to i32 // CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32> // CHECK: %[[offsetVec2:.*]] = addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32> // -// 4. Let dim the memref dimension, compute the vector comparison mask: +// 3. Let dim the memref dimension, compute the vector comparison mask: // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] // CHECK: %[[dtrunc:.*]] = index_cast %[[DIM]] : index to i32 // CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32> // CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32> // +// 4. Bitcast to vector form. 
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : +// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : +// CHECK-SAME: !llvm.ptr to !llvm.ptr> +// // 5. Rewrite as a masked read. // CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], @@ -1081,26 +1081,26 @@ // CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xf32>) -> vector<17xf32> // -// 1. Bitcast to vector form. -// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] : -// CHECK-SAME: !llvm.ptr to !llvm.ptr> -// -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. +// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // CHECK: %[[linearIndex_b:.*]] = constant dense // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : // CHECK-SAME: vector<17xi32> // -// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. +// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. // CHECK: splat %{{.*}} : vector<17xi32> // CHECK: addi // -// 4. Let dim the memref dimension, compute the vector comparison mask: +// 3. Let dim the memref dimension, compute the vector comparison mask: // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] // CHECK: splat %{{.*}} : vector<17xi32> // CHECK: %[[mask_b:.*]] = cmpi slt, {{.*}} : vector<17xi32> // +// 4. Bitcast to vector form. +// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : +// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] : +// CHECK-SAME: !llvm.ptr to !llvm.ptr> +// // 5. Rewrite as a masked write. 
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]] // CHECK-SAME: {alignment = 4 : i32} : @@ -1182,6 +1182,21 @@ // ----- +// CHECK-LABEL: func @transfer_read_1d_mask +// CHECK: %[[mask1:.*]] = constant dense<[false, false, true, false, true]> +// CHECK: %[[cmpi:.*]] = cmpi slt +// CHECK: %[[mask2:.*]] = and %[[cmpi]], %[[mask1]] +// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]] +// CHECK: return %[[r]] +func @transfer_read_1d_mask(%A : memref, %base : index) -> vector<5xf32> { + %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1> + %f7 = constant 7.0: f32 + %f = vector.transfer_read %A[%base], %f7, %m : memref, vector<5xf32> + return %f: vector<5xf32> +} + +// ----- + func @transfer_read_1d_cast(%A : memref, %base: index) -> vector<12xi8> { %c0 = constant 0: i32 %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} : diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -11,6 +11,7 @@ %c0 = constant 0 : i32 %vf0 = splat %f0 : vector<4x3xf32> %v0 = splat %c0 : vector<4x3xi32> + %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1> // // CHECK: vector.transfer_read @@ -27,7 +28,8 @@ %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref>, vector<1x1x4x3xf32> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref>, vector<5x24xi8> %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref>, vector<5x24xi8> - + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref, vector<5xf32> + %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref, vector<5xf32> // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref @@ -39,7 +41,8 @@ vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref> vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref> - + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref + vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read.mlir @@ -12,6 +12,14 @@ return } +func @transfer_read_mask_1d(%A : memref, %base: index) { + %fm42 = constant -42.0: f32 + %m = constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1> + %f = vector.transfer_read %A[%base], %fm42, %m : memref, vector<13xf32> + vector.print %f: vector<13xf32> + return +} + func @transfer_read_inbounds_4(%A : memref, %base: index) { %fm42 = constant -42.0: f32 %f = vector.transfer_read %A[%base], %fm42 @@ -21,6 +29,15 @@ return } +func @transfer_read_mask_inbounds_4(%A : memref, %base: index) { + %fm42 = constant -42.0: f32 + %m = constant dense<[0, 1, 0, 1]> : vector<4xi1> + %f = vector.transfer_read %A[%base], %fm42, %m {in_bounds = [true]} + : memref, vector<4xf32> + vector.print %f: vector<4xf32> + return +} + func @transfer_write_1d(%A : memref, %base: index) { %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<4xf32> @@ -47,6 +64,8 @@ // Read shifted by 2 and pad with -42: // ( 2, 3, 
4, -42, ..., -42) call @transfer_read_1d(%A, %c2) : (memref, index) -> () + // Read with mask and out-of-bounds access. + call @transfer_read_mask_1d(%A, %c2) : (memref, index) -> () // Write into memory shifted by 3 // memory contains [[ 0, 1, 2, 0, 0, xxx garbage xxx ]] call @transfer_write_1d(%A, %c3) : (memref, index) -> () @@ -56,9 +75,13 @@ // Read in-bounds 4 @ 1, guaranteed to not overflow. // Exercises proper alignment. call @transfer_read_inbounds_4(%A, %c1) : (memref, index) -> () + // Read in-bounds with mask. + call @transfer_read_mask_inbounds_4(%A, %c1) : (memref, index) -> () return } // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 ) +// CHECK: ( -42, -42, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 ) // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 ) // CHECK: ( 1, 2, 0, 0 ) +// CHECK: ( -42, 2, -42, 0 )
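
Beyond the in-tree updates above, the recurring adaptor and `mask()` changes (VectorToROCDL and VectorToSCF show the pattern) suggest what out-of-tree conversion patterns for `vector.transfer_read`/`vector.transfer_write` will need: construct adaptors with the op's attribute dictionary, since AttrSizedOperandSegments makes operand decoding depend on `operand_segment_sizes`, and bail out on the new explicit mask unless they handle it. A hedged sketch under those assumptions; the pattern class name is hypothetical and the actual lowering is elided:

// Sketch only: `MyTransferReadLowering` stands in for an out-of-tree pattern.
struct MyTransferReadLowering
    : public ConvertOpToLLVMPattern<vector::TransferReadOp> {
  using ConvertOpToLLVMPattern<vector::TransferReadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Operand decoding now needs the attributes as well as the remapped
    // operands.
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            xferOp->getAttrDictionary());
    // Patterns written before the explicit mask operand should not fire on
    // masked transfers.
    if (xferOp.mask())
      return failure();
    // ... emit the lowered load from adaptor.source(), adaptor.indices(),
    // adaptor.padding(), then replace xferOp ...
    return success();
  }
};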