Index: mlir/include/mlir/Dialect/Vector/VectorOps.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorOps.h +++ mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -126,7 +126,7 @@ /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the /// rank of the identity map must take the vector element type into account. -AffineMap getTransferMinorIdentityMap(MemRefType memRefType, +AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType); } // namespace impl } // end namespace vector Index: mlir/include/mlir/Dialect/Vector/VectorOps.td =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorOps.td +++ mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1056,7 +1056,7 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, - Arguments<(ins AnyMemRef:$memref, Variadic:$indices, + Arguments<(ins AnyShaped:$source, Variadic:$indices, AffineMapAttr:$permutation_map, AnyType:$padding, OptionalAttr:$masked)>, Results<(outs AnyVector:$vector)> { @@ -1065,15 +1065,16 @@ let description = [{ The `vector.transfer_read` op performs a read from a slice within a - [MemRef](../LangRef.md#memref-type) supplied as its first operand - into a [vector](../LangRef.md#vector-type) of the same base elemental type. + [MemRef](../LangRef.md#memref-type) or a Ranked + [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a + [vector](../LangRef.md#vector-type) of the same base elemental type. - A memref operand with vector element type, must have its vector element - type match a suffix (shape and element type) of the vector (e.g. + A memref/tensor operand with vector element type, must have its vector + element type match a suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>). 
- The slice is further defined by a full-rank index within the MemRef, - supplied as the operands `2 .. 1 + rank(memref)`. + The slice is further defined by a full-rank index within the MemRef/Tensor, + supplied as the operands `2 .. 1 + rank(memref/tensor)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the @@ -1084,8 +1085,9 @@ The size of the slice is specified by the size of the vector, given as the return type. - An `ssa-value` of the same elemental type as the MemRef is provided as the - last operand to specify padding in the case of out-of-bounds accesses. + An `ssa-value` of the same elemental type as the MemRef/Tensor is provided + as the last operand to specify padding in the case of out-of-bounds + accesses. An optional boolean array attribute is provided to specify which dimensions of the transfer need masking. When a dimension is specified as not requiring @@ -1196,17 +1198,22 @@ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> + + // Read from a tensor with vector element type. + %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 + {permutation_map = (d0, d1)->(d0, d1)} + : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> ``` }]; let builders = [ // Builder that sets padding to zero. - OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref, + OpBuilderDAG<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$maybeMasked)>, // Builder that sets permutation map (resp. padding) to // 'getMinorIdentityMap' (resp. zero).
- OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref, + OpBuilderDAG<(ins "VectorType":$vector, "Value":$source, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$maybeMasked)> ]; @@ -1217,26 +1224,29 @@ Vector_Op<"transfer_write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods - ]>, - Arguments<(ins AnyVector:$vector, AnyMemRef:$memref, + ]>, + Arguments<(ins AnyVector:$vector, AnyShaped:$source, Variadic:$indices, AffineMapAttr:$permutation_map, - OptionalAttr:$masked)> { + OptionalAttr:$masked)>, + Results<(outs Optional:$result)> { let summary = "The vector.transfer_write op writes a supervector to memory."; let description = [{ The `vector.transfer_write` op performs a write from a [vector](../LangRef.md#vector-type), supplied as its first operand, into a - slice within a [MemRef](../LangRef.md#memref-type) of the same base - elemental type, supplied as its second operand. + slice within a [MemRef](../LangRef.md#memref-type) or a Ranked + [Tensor](../LangRef.md#tensor-type) of the same base elemental type, + supplied as its second operand. - A vector memref operand must have its vector element type match a suffix - (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, - vector<1x1x4x3xf32>). + A vector memref/tensor operand must have its vector element type match a + suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a + new tensor of the same type. - The slice is further defined by a full-rank index within the MemRef, - supplied as the operands `3 .. 2 + rank(memref)`. + The slice is further defined by a full-rank index within the MemRef/Tensor, + supplied as the operands `3 .. 2 + rank(memref/tensor)`. 
The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the @@ -1280,15 +1290,24 @@ vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> + + // Return a tensor where the vector is inserted into the source tensor. + %5 = vector.transfer_write %4, %arg1[%c3, %c3] + {permutation_map = (d0, d1)->(d0, d1)} + : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>> ``` }]; let builders = [ // Builder that sets permutation map to 'getMinorIdentityMap'. - OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices, + OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$maybeMasked)>, - OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices, + OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, "AffineMap":$permutationMap)>, + OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, + "AffineMapAttr":$permutationMap, "ArrayAttr":$masked)>, + OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, + "AffineMap":$permutationMap, "ArrayAttr":$masked)>, ]; let hasFolder = 1; Index: mlir/include/mlir/Dialect/Vector/VectorUtils.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -20,9 +20,9 @@ class AffineForOp; class AffineMap; class Location; -class MemRefType; class OpBuilder; class Operation; +class ShapedType; class Value; class VectorType; class VectorTransferOpInterface; @@ -157,7 +157,7 @@ /// Build the default minor identity map suitable for a vector transfer. This /// also handles the case memref<... x vector<...>> -> vector<...> in which the /// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType, +AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType); /// Return true if we can prove that the transfer operations access disjoint Index: mlir/include/mlir/Interfaces/VectorInterfaces.td =================================================================== --- mlir/include/mlir/Interfaces/VectorInterfaces.td +++ mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -47,7 +47,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { let description = [{ - Encodes properties of an operation on vectors that can be unrolled. + Encodes properties of a transfer read or write operation. }]; let cppNamespace = "::mlir"; @@ -83,11 +83,11 @@ }] >, InterfaceMethod< - /*desc=*/"Return the memref operand.", + /*desc=*/"Return the memref or ranked tensor operand.", /*retTy=*/"Value", - /*methodName=*/"memref", + /*methodName=*/"source", /*args=*/(ins), - /*methodBody=*/"return $_op.memref();" + /*methodBody=*/"return $_op.source();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -123,13 +123,13 @@ /*defaultImplementation=*/ >, InterfaceMethod< - /*desc=*/"Return the MemRefType.", - /*retTy=*/"MemRefType", - /*methodName=*/"getMemRefType", + /*desc=*/"Return the ShapedType.", + /*retTy=*/"ShapedType", + /*methodName=*/"getShapedType", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ - "return $_op.memref().getType().template cast();" + "return $_op.source().getType().template cast();" >, InterfaceMethod< /*desc=*/"Return the VectorType.", @@ -152,14 +152,14 @@ "return $_op.permutation_map().getNumResults();" >, InterfaceMethod< - /*desc=*/[{ Return the number of leading memref dimensions that do not + /*desc=*/[{ Return the number of leading shaped dimensions that do not participate in the permutation map.}], /*retTy=*/"unsigned", - /*methodName=*/"getLeadingMemRefRank", + /*methodName=*/"getLeadingShapedRank", /*args=*/(ins), /*methodBody=*/"", 
/*defaultImplementation=*/ - "return $_op.getMemRefType().getRank() - $_op.getTransferRank();" + "return $_op.getShapedType().getRank() - $_op.getTransferRank();" >, InterfaceMethod< /*desc=*/[{ Returns true if at least one of the dimensions is masked.}], @@ -178,8 +178,8 @@ /*desc=*/[{ Helper function to account for the fact that `permutationMap` results and `op.indices` sizes may not match and may not be aligned. The first - `getLeadingMemRefRank()` indices may just be indexed and not transferred - from/into the vector. + `getLeadingShapedRank()` indices may just be indexed and not + transferred from/into the vector. For example: ``` vector.transfer %0[%i, %j, %k, %c0] : @@ -195,7 +195,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ for (int64_t resultIdx = 0, - indicesIdx = $_op.getLeadingMemRefRank(), + indicesIdx = $_op.getLeadingShapedRank(), eResult = $_op.getTransferRank(); resultIdx < eResult; ++resultIdx, ++indicesIdx) Index: mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp =================================================================== --- mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -22,6 +22,17 @@ using namespace mlir; +/// Helpers to access the memref operand for each op. +static Value getMemRefOperand(LoadOp op) { return op.memref(); } + +static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } + +static Value getMemRefOperand(StoreOp op) { return op.memref(); } + +static Value getMemRefOperand(vector::TransferWriteOp op) { + return op.source(); +} + namespace { /// Merges subview operation with load/transferRead operation. 
template @@ -141,7 +152,7 @@ LogicalResult LoadOpOfSubViewFolder::matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const { - auto subViewOp = loadOp.memref().template getDefiningOp(); + auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp(); if (!subViewOp) { return failure(); } @@ -162,7 +173,8 @@ LogicalResult StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const { - auto subViewOp = storeOp.memref().template getDefiningOp(); + auto subViewOp = + getMemRefOperand(storeOp).template getDefiningOp(); if (!subViewOp) { return failure(); } Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp =================================================================== --- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -141,12 +141,10 @@ return rewriter.create(loc, CmpIPredicate::slt, indices, bounds); } -// Helper that returns data layout alignment of an operation with memref. -template -LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, - unsigned &align) { - Type elementTy = - typeConverter.convertType(op.getMemRefType().getElementType()); +// Helper that returns data layout alignment of a memref. 
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, + MemRefType memrefType, unsigned &align) { + Type elementTy = typeConverter.convertType(memrefType.getElementType()); if (!elementTy) return failure(); @@ -222,7 +220,8 @@ TransferReadOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; - if (failed(getMemRefAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment( + typeConverter, xferOp.getShapedType().cast(), align))) return failure(); rewriter.replaceOpWithNewOp(xferOp, dataPtr, align); return success(); @@ -243,7 +242,8 @@ return failure(); unsigned align; - if (failed(getMemRefAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment( + typeConverter, xferOp.getShapedType().cast(), align))) return failure(); rewriter.replaceOpWithNewOp( @@ -258,7 +258,8 @@ TransferWriteOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; - if (failed(getMemRefAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment( + typeConverter, xferOp.getShapedType().cast(), align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr, @@ -272,7 +273,8 @@ TransferWriteOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { unsigned align; - if (failed(getMemRefAlignment(typeConverter, xferOp, align))) + if (failed(getMemRefAlignment( + typeConverter, xferOp.getShapedType().cast(), align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); @@ -345,7 +347,8 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), load, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(), + align))) return failure(); auto vtype = typeConverter->convertType(load.getResultVectorType()); @@ -375,7 +378,8 @@ // Resolve alignment. 
unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), store, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(), + align))) return failure(); auto vtype = typeConverter->convertType(store.getValueVectorType()); @@ -405,7 +409,8 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), gather, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), + align))) return failure(); // Get index ptrs. @@ -438,7 +443,8 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), + align))) return failure(); // Get index ptrs. @@ -1182,8 +1188,11 @@ xferOp.getVectorType().getRank(), xferOp->getContext())) return failure(); + auto memRefType = xferOp.getShapedType().template dyn_cast(); + if (!memRefType) + return failure(); // Only contiguous source tensors supported atm. - auto strides = computeContiguousStrides(xferOp.getMemRefType()); + auto strides = computeContiguousStrides(memRefType); if (!strides) return failure(); @@ -1192,10 +1201,9 @@ }; Location loc = xferOp->getLoc(); - MemRefType memRefType = xferOp.getMemRefType(); if (auto memrefVectorElementType = - memRefType.getElementType().dyn_cast()) { + memRefType.getElementType().template dyn_cast()) { // Memref has vector element type. if (memrefVectorElementType.getElementType() != xferOp.getVectorType().getElementType()) @@ -1222,7 +1230,7 @@ // address space 0. // TODO: support alignment when possible. 
Value dataPtr = this->getStridedElementPtr( - loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); + loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); Value vectorDataPtr; @@ -1248,7 +1256,7 @@ unsigned vecWidth = vecTy.getVectorNumElements(); unsigned lastIndex = llvm::size(xferOp.indices()) - 1; Value off = xferOp.indices()[lastIndex]; - Value dim = rewriter.create(loc, xferOp.memref(), lastIndex); + Value dim = rewriter.create(loc, xferOp.source(), lastIndex); Value mask = buildVectorComparison( rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); Index: mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp =================================================================== --- mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -89,7 +89,9 @@ return failure(); // Obtain dataPtr and elementType from the memref. - MemRefType memRefType = xferOp.getMemRefType(); + auto memRefType = xferOp.getShapedType().template dyn_cast(); + if (!memRefType) + return failure(); // MUBUF instruction operate only on addresspace 0(unified) or 1(global) // In case of 3(LDS): fall back to vector->llvm pass // In case of 5(VGPR): wrong @@ -101,7 +103,7 @@ // indices, so no need to calculate offset size in bytes again in // the MUBUF instruction. Value dataPtr = this->getStridedElementPtr( - loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); + loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. Index: mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp =================================================================== --- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -107,7 +107,7 @@ // TODO: when we go to k > 1-D vectors adapt minorRank. 
minorRank = 1; majorRank = vectorType.getRank() - minorRank; - leadingRank = xferOp.getLeadingMemRefRank(); + leadingRank = xferOp.getLeadingShapedRank(); majorVectorType = VectorType::get(vectorType.getShape().take_front(majorRank), vectorType.getElementType()); @@ -115,9 +115,9 @@ VectorType::get(vectorType.getShape().take_back(minorRank), vectorType.getElementType()); /// Memref of minor vector type is used for individual transfers. - memRefMinorVectorType = - MemRefType::get(majorVectorType.getShape(), minorVectorType, {}, - xferOp.getMemRefType().getMemorySpace()); + memRefMinorVectorType = MemRefType::get( + majorVectorType.getShape(), minorVectorType, {}, + xferOp.getShapedType().template cast().getMemorySpace()); } LogicalResult doReplace(); @@ -155,7 +155,7 @@ const MemRefBoundsCapture &)> loopBodyBuilder) { /// Loop nest operates on the major dimensions - MemRefBoundsCapture memrefBoundsCapture(xferOp.memref()); + MemRefBoundsCapture memrefBoundsCapture(xferOp.source()); if (options.unroll) { auto shape = majorVectorType.getShape(); @@ -272,9 +272,9 @@ indexing.append(leadingOffsets.begin(), leadingOffsets.end()); indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end()); indexing.append(minorOffsets.begin(), minorOffsets.end()); - Value memref = xferOp.memref(); + Value memref = xferOp.source(); auto map = - getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType); + getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType); ArrayAttr masked; if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); @@ -379,13 +379,13 @@ else result = std_load(alloc, majorIvs); auto map = - getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType); + getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType); ArrayAttr masked; if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) { OpBuilder &b = ScopedContext::getBuilderRef(); masked = 
b.getBoolArrayAttr({false}); } - vector_transfer_write(result, xferOp.memref(), indexing, + vector_transfer_write(result, xferOp.source(), indexing, AffineMapAttr::get(map), masked); }; @@ -422,7 +422,7 @@ static int computeCoalescedIndex(TransferOpTy transfer) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. - auto remoteRank = transfer.getMemRefType().getRank(); + auto remoteRank = transfer.getShapedType().getRank(); // Iterate over the results expressions of the permutation map to determine // the loop order for creating pointwise copies between remote and local // memories. @@ -536,13 +536,14 @@ using namespace mlir::edsc::op; TransferReadOp transfer = cast(op); - + auto memRefType = transfer.getShapedType().dyn_cast(); + if (!memRefType) + return failure(); // Fall back to a loop if the fastest varying stride is not 1 or it is // permuted. int64_t offset; SmallVector strides; - auto successStrides = - getStridesAndOffset(transfer.getMemRefType(), strides, offset); + auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (succeeded(successStrides) && strides.back() == 1 && transfer.permutation_map().isMinorIdentity()) { // If > 1D, emit a bunch of loops around 1-D vector transfers. @@ -557,8 +558,8 @@ // Conservative lowering to scalar load / stores. // 1. Setup all the captures. ScopedContext scope(rewriter, transfer.getLoc()); - StdIndexedValue remote(transfer.memref()); - MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); + StdIndexedValue remote(transfer.source()); + MemRefBoundsCapture memRefBoundsCapture(transfer.source()); VectorBoundsCapture vectorBoundsCapture(transfer.vector()); int coalescedIdx = computeCoalescedIndex(transfer); // Swap the vectorBoundsCapture which will reorder loop bounds. 
@@ -621,13 +622,15 @@ using namespace edsc::op; TransferWriteOp transfer = cast(op); + auto memRefType = transfer.getShapedType().template dyn_cast(); + if (!memRefType) + return failure(); // Fall back to a loop if the fastest varying stride is not 1 or it is // permuted. int64_t offset; SmallVector strides; - auto successStrides = - getStridesAndOffset(transfer.getMemRefType(), strides, offset); + auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (succeeded(successStrides) && strides.back() == 1 && transfer.permutation_map().isMinorIdentity()) { // If > 1D, emit a bunch of loops around 1-D vector transfers. @@ -641,8 +644,8 @@ // 1. Setup all the captures. ScopedContext scope(rewriter, transfer.getLoc()); - StdIndexedValue remote(transfer.memref()); - MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); + StdIndexedValue remote(transfer.source()); + MemRefBoundsCapture memRefBoundsCapture(transfer.source()); Value vectorValue(transfer.vector()); VectorBoundsCapture vectorBoundsCapture(transfer.vector()); int coalescedIdx = computeCoalescedIndex(transfer); Index: mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -111,7 +111,7 @@ vector::TransferWriteOp transferWrite; for (auto *sliceOp : llvm::reverse(forwardSlice)) { auto candidateWrite = dyn_cast(sliceOp); - if (!candidateWrite || candidateWrite.memref() != transferRead.memref()) + if (!candidateWrite || candidateWrite.source() != transferRead.source()) continue; transferWrite = candidateWrite; } @@ -142,7 +142,7 @@ DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); - for (auto &use : transferRead.memref().getUses()) { + for (auto &use : transferRead.source().getUses()) { if (!dom.properlyDominates(loop, use.getOwner())) continue; if 
(use.getOwner() == transferRead.getOperation() || Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -411,7 +411,7 @@ vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { // Transfer into `view`. - Value viewOrAlloc = xferOp.memref(); + Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) return failure(); @@ -487,7 +487,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { // Transfer into `viewOrAlloc`. - Value viewOrAlloc = xferOp.memref(); + Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) return failure(); Index: mlir/lib/Dialect/Vector/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorOps.cpp +++ mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1890,41 +1890,43 @@ return success(); } -static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType, +static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType, VectorType vectorType, AffineMap permutationMap, ArrayAttr optionalMasked) { - auto memrefElementType = memrefType.getElementType(); - if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { - // Memref has vector element type. - - unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() * - memrefVectorElementType.getShape().back(); + if (!shapedType.isa()) + return op->emitOpError( + "requires source to be a memref or ranked tensor type"); + auto elementType = shapedType.getElementType(); + if (auto vectorElementType = elementType.dyn_cast()) { + // Memref or tensor has vector element type. 
+ unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() * + vectorElementType.getShape().back(); unsigned resultVecSize = vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); - if (resultVecSize % memrefVecSize != 0) + if (resultVecSize % sourceVecSize != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " - "multiple of the bitwidth of the minor 1-D vector of the memref"); + "multiple of the bitwidth of the minor 1-D vector of the source"); - unsigned memrefVecEltRank = memrefVectorElementType.getRank(); + unsigned sourceVecEltRank = vectorElementType.getRank(); unsigned resultVecRank = vectorType.getRank(); - if (memrefVecEltRank > resultVecRank) + if (sourceVecEltRank > resultVecRank) return op->emitOpError( - "requires memref vector element and vector result ranks to match."); - unsigned rankOffset = resultVecRank - memrefVecEltRank; + "requires source vector element and vector result ranks to match."); + unsigned rankOffset = resultVecRank - sourceVecEltRank; // Check that permutation map results match 'rankOffset' of vector type. if (permutationMap.getNumResults() != rankOffset) return op->emitOpError("requires a permutation_map with result dims of " "the same rank as the vector type"); } else { - // Memref has scalar element type. + // Memref or tensor has scalar element type. unsigned resultVecSize = vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); - if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0) + if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " - "multiple of the bitwidth of the memref element type"); + "multiple of the bitwidth of the source element type"); // Check that permutation map results match rank of vector type. 
if (permutationMap.getNumResults() != vectorType.getRank()) @@ -1934,9 +1936,9 @@ if (permutationMap.getNumSymbols() != 0) return op->emitOpError("requires permutation_map without symbols"); - if (permutationMap.getNumInputs() != memrefType.getRank()) + if (permutationMap.getNumInputs() != shapedType.getRank()) return op->emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); + "same rank as the source type"); if (optionalMasked) { if (permutationMap.getNumResults() != @@ -1978,7 +1980,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; if (op.permutation_map() == - getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType())) + getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType())) elidedAttrs.push_back(op.getPermutationMapAttrName()); bool elideMasked = true; if (auto maybeMasked = op.masked()) { @@ -1995,21 +1997,21 @@ } static void print(OpAsmPrinter &p, TransferReadOp op) { - p << op.getOperationName() << " " << op.memref() << "[" << op.indices() + p << op.getOperationName() << " " << op.source() << "[" << op.indices() << "], " << op.padding(); printTransferAttrs(p, cast(op.getOperation())); - p << " : " << op.getMemRefType() << ", " << op.getVectorType(); + p << " : " << op.getShapedType() << ", " << op.getVectorType(); } static ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc typesLoc; - OpAsmParser::OperandType memrefInfo; + OpAsmParser::OperandType sourceInfo; SmallVector indexInfo; OpAsmParser::OperandType paddingInfo; SmallVector types; // Parsing with support for paddingValue. 
- if (parser.parseOperand(memrefInfo) || + if (parser.parseOperand(sourceInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(paddingInfo) || parser.parseOptionalAttrDict(result.attributes) || @@ -2018,48 +2020,48 @@ if (types.size() != 2) return parser.emitError(typesLoc, "requires two types"); auto indexType = parser.getBuilder().getIndexType(); - MemRefType memRefType = types[0].dyn_cast(); - if (!memRefType) - return parser.emitError(typesLoc, "requires memref type"); + auto shapedType = types[0].dyn_cast(); + if (!shapedType || !shapedType.isa()) + return parser.emitError(typesLoc, "requires memref or ranked tensor type"); VectorType vectorType = types[1].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { - auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); + auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return failure( - parser.resolveOperand(memrefInfo, memRefType, result.operands) || + parser.resolveOperand(sourceInfo, shapedType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || - parser.resolveOperand(paddingInfo, memRefType.getElementType(), + parser.resolveOperand(paddingInfo, shapedType.getElementType(), result.operands) || parser.addTypeToList(vectorType, result.types)); } static LogicalResult verify(TransferReadOp op) { - // Consistency of elemental types in memref and vector. - MemRefType memrefType = op.getMemRefType(); + // Consistency of elemental types in source and vector. 
+ ShapedType shapedType = op.getShapedType(); VectorType vectorType = op.getVectorType(); auto paddingType = op.padding().getType(); auto permutationMap = op.permutation_map(); - auto memrefElementType = memrefType.getElementType(); + auto sourceElementType = shapedType.getElementType(); - if (static_cast(op.indices().size()) != memrefType.getRank()) - return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + if (static_cast(op.indices().size()) != shapedType.getRank()) + return op.emitOpError("requires ") << shapedType.getRank() << " indices"; - if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType, permutationMap, op.masked() ? *op.masked() : ArrayAttr()))) return failure(); - if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { - // Memref has vector element type. - // Check that 'memrefVectorElementType' and 'paddingType' types match. - if (memrefVectorElementType != paddingType) + if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { + // Source has vector element type. + // Check that 'sourceVectorElementType' and 'paddingType' types match. + if (sourceVectorElementType != paddingType) return op.emitOpError( - "requires memref element type and padding type to match."); + "requires source element type and padding type to match."); } else { // Check that 'paddingType' is valid to store in a vector type. @@ -2067,9 +2069,9 @@ return op.emitOpError("requires valid padding vector elemental type"); // Check that padding type and vector element types match. 
- if (paddingType != memrefElementType) + if (paddingType != sourceElementType) return op.emitOpError( - "requires formal padding and memref of the same elemental type"); + "requires formal padding and source of the same elemental type"); } return verifyPermutationMap(permutationMap, @@ -2096,18 +2098,18 @@ template static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { // TODO: support more aggressive createOrFold on: - // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)` - if (op.getMemRefType().isDynamicDim(indicesIdx)) + // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` + if (op.getShapedType().isDynamicDim(indicesIdx)) return false; Value index = op.indices()[indicesIdx]; auto cstOp = index.getDefiningOp(); if (!cstOp) return false; - int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx); + int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); - return cstOp.getValue() + vectorSize <= memrefSize; + return cstOp.getValue() + vectorSize <= sourceSize; } template @@ -2159,33 +2161,51 @@ /// Builder that sets permutation map to 'getMinorIdentityMap'. 
void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value memref, ValueRange indices, + Value vector, Value source, ValueRange indices, ArrayRef maybeMasked) { auto vectorType = vector.getType().cast(); auto permMap = getTransferMinorIdentityMap( - memref.getType().cast(), vectorType); + source.getType().cast(), vectorType); if (maybeMasked.empty()) - return build(builder, result, vector, memref, indices, permMap, + return build(builder, result, vector, source, indices, permMap, ArrayAttr()); ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked); - build(builder, result, vector, memref, indices, permMap, maskedArrayAttr); + build(builder, result, vector, source, indices, permMap, maskedArrayAttr); } void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value memref, ValueRange indices, + Value vector, Value source, ValueRange indices, AffineMap permutationMap) { - build(builder, result, vector, memref, indices, permutationMap, + build(builder, result, vector, source, indices, permutationMap, /*maybeMasked=*/ArrayAttr()); } +void TransferWriteOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value source, ValueRange indices, + AffineMapAttr permutationMap, + /*optional*/ ArrayAttr masked) { + Type resultType = source.getType().dyn_cast(); + build(builder, result, resultType, vector, source, indices, permutationMap, + masked); +} + +void TransferWriteOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value source, ValueRange indices, + AffineMap permutationMap, + /*optional*/ ArrayAttr masked) { + Type resultType = source.getType().dyn_cast(); + build(builder, result, resultType, vector, source, indices, permutationMap, + masked); +} + static ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc typesLoc; - OpAsmParser::OperandType vectorInfo, memrefInfo; + OpAsmParser::OperandType vectorInfo, sourceInfo; 
SmallVector indexInfo; SmallVector types; if (parser.parseOperand(vectorInfo) || parser.parseComma() || - parser.parseOperand(memrefInfo) || + parser.parseOperand(sourceInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) @@ -2196,38 +2216,40 @@ VectorType vectorType = types[0].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - MemRefType memRefType = types[1].dyn_cast(); - if (!memRefType) - return parser.emitError(typesLoc, "requires memref type"); + ShapedType shapedType = types[1].dyn_cast(); + if (!shapedType || !shapedType.isa()) + return parser.emitError(typesLoc, "requires memref or ranked tensor type"); auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { - auto permMap = getTransferMinorIdentityMap(memRefType, vectorType); + auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } return failure( parser.resolveOperand(vectorInfo, vectorType, result.operands) || - parser.resolveOperand(memrefInfo, memRefType, result.operands) || - parser.resolveOperands(indexInfo, indexType, result.operands)); + parser.resolveOperand(sourceInfo, shapedType, result.operands) || + parser.resolveOperands(indexInfo, indexType, result.operands) || + (shapedType.isa() && + parser.addTypeToList(shapedType, result.types))); } static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" + p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]"; printTransferAttrs(p, cast(op.getOperation())); - p << " : " << op.getVectorType() << ", " << op.getMemRefType(); + p << " : " << op.getVectorType() << ", " << 
op.getShapedType(); } static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. - MemRefType memrefType = op.getMemRefType(); + ShapedType shapedType = op.getShapedType(); VectorType vectorType = op.getVectorType(); auto permutationMap = op.permutation_map(); - if (llvm::size(op.indices()) != memrefType.getRank()) - return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + if (llvm::size(op.indices()) != shapedType.getRank()) + return op.emitOpError("requires ") << shapedType.getRank() << " indices"; - if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType, permutationMap, op.masked() ? *op.masked() : ArrayAttr()))) return failure(); Index: mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp @@ -94,7 +94,7 @@ << "\n"); llvm::SmallVector reads; Operation *firstOverwriteCandidate = nullptr; - for (auto *user : write.memref().getUsers()) { + for (auto *user : write.source().getUsers()) { if (user == write.getOperation()) continue; if (auto nextWrite = dyn_cast(user)) { @@ -163,7 +163,7 @@ << "\n"); SmallVector blockingWrites; vector::TransferWriteOp lastwrite = nullptr; - for (Operation *user : read.memref().getUsers()) { + for (Operation *user : read.source().getUsers()) { if (isa(user)) continue; if (auto write = dyn_cast(user)) { Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -597,7 +597,7 @@ Location loc = readOp.getLoc(); auto memrefElementType = - readOp.memref().getType().cast().getElementType(); + readOp.source().getType().cast().getElementType(); auto 
tupleType = generateExtractSlicesOpResultType( sourceVectorType, targetShape, strides, builder); int64_t numSlices = tupleType.size(); @@ -612,7 +612,7 @@ // `masked` attribute propagates conservatively: if the coarse op didn't // need masking, the fine op doesn't either. vectorTupleValues[index] = builder.create( - loc, sliceVectorType, readOp.memref(), sliceIndices, + loc, sliceVectorType, readOp.source(), sliceIndices, readOp.permutation_map(), readOp.padding(), readOp.masked() ? *readOp.masked() : ArrayAttr()); }; @@ -644,14 +644,14 @@ Value tuple = builder.create( loc, tupleType, writeOp.vector(), targetShape, strides); auto memrefElementType = - writeOp.memref().getType().cast().getElementType(); + writeOp.source().getType().cast().getElementType(); SmallVector indices(writeOp.indices().begin(), writeOp.indices().end()); auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { auto element = builder.create( loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index)); builder.create( - loc, element.getResult(), writeOp.memref(), sliceIndices, + loc, element.getResult(), writeOp.source(), sliceIndices, writeOp.permutation_map(), writeOp.masked() ? *writeOp.masked() : ArrayAttr()); }; @@ -760,7 +760,7 @@ Location loc = xferWriteOp.getLoc(); auto memrefElementType = - xferWriteOp.memref().getType().cast().getElementType(); + xferWriteOp.source().getType().cast().getElementType(); SmallVector indices(xferWriteOp.indices().begin(), xferWriteOp.indices().end()); auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { @@ -768,7 +768,7 @@ // `masked` attribute propagates conservatively: if the coarse op didn't // need masking, the fine op doesn't either. rewriter.create( - loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices, + loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices, xferWriteOp.permutation_map(), xferWriteOp.masked() ? 
*xferWriteOp.masked() : ArrayAttr()); }; @@ -2142,7 +2142,7 @@ // Fold or create the check that `index + vector_size` <= `memref_size`. Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize); Value cond = - createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx)); + createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx)); if (!cond) return; // Conjunction over all dims for which we are in-bounds. @@ -2207,23 +2207,23 @@ } /// Operates under a scoped context to build the intersection between the -/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`. +/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`. // TODO: view intersection/union/differences should be a proper std op. static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp, Value alloc) { using namespace edsc::intrinsics; - int64_t memrefRank = xferOp.getMemRefType().getRank(); + int64_t memrefRank = xferOp.getShapedType().getRank(); // TODO: relax this precondition, will require rank-reducing subviews. 
assert(memrefRank == alloc.getType().cast().getRank() && "Expected memref rank to match the alloc rank"); Value one = std_constant_index(1); ValueRange leadingIndices = - xferOp.indices().take_front(xferOp.getLeadingMemRefRank()); + xferOp.indices().take_front(xferOp.getLeadingShapedRank()); SmallVector sizes; sizes.append(leadingIndices.begin(), leadingIndices.end()); xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { using MapList = ArrayRef>; - Value dimMemRef = std_dim(xferOp.memref(), indicesIdx); + Value dimMemRef = std_dim(xferOp.source(), indicesIdx); Value dimAlloc = std_dim(alloc, resultIdx); Value index = xferOp.indices()[indicesIdx]; AffineExpr i, j, k; @@ -2235,7 +2235,7 @@ ValueRange{dimMemRef, index, dimAlloc}); sizes.push_back(affineMin); }); - return std_sub_view(xferOp.memref(), xferOp.indices(), sizes, + return std_sub_view(xferOp.source(), xferOp.indices(), sizes, SmallVector(memrefRank, one)); } @@ -2263,12 +2263,12 @@ using namespace edsc::intrinsics; scf::IfOp fullPartialIfOp; Value zero = std_constant_index(0); - Value memref = xferOp.memref(); + Value memref = xferOp.source(); conditionBuilder( returnTypes, inBoundsCond, [&]() -> scf::ValueVector { Value res = memref; - if (compatibleMemRefType != xferOp.getMemRefType()) + if (compatibleMemRefType != xferOp.getShapedType()) res = std_memref_cast(memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), @@ -2317,12 +2317,12 @@ using namespace edsc::intrinsics; scf::IfOp fullPartialIfOp; Value zero = std_constant_index(0); - Value memref = xferOp.memref(); + Value memref = xferOp.source(); conditionBuilder( returnTypes, inBoundsCond, [&]() -> scf::ValueVector { Value res = memref; - if (compatibleMemRefType != xferOp.getMemRefType()) + if (compatibleMemRefType != xferOp.getShapedType()) res = std_memref_cast(memref, compatibleMemRefType); scf::ValueVector viewAndIndices{res}; 
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), @@ -2376,7 +2376,7 @@ /// /// Preconditions: /// 1. `xferOp.permutation_map()` must be a minor identity map -/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` +/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()` /// must be equal. This will be relaxed in the future but requires /// rank-reducing subviews. LogicalResult mlir::vector::splitFullAndPartialTransfer( @@ -2404,8 +2404,8 @@ return failure(); OpBuilder::InsertionGuard guard(b); - if (xferOp.memref().getDefiningOp()) - b.setInsertionPointAfter(xferOp.memref().getDefiningOp()); + if (Operation *sourceOp = xferOp.source().getDefiningOp()) + b.setInsertionPointAfter(sourceOp); else b.setInsertionPoint(xferOp); ScopedContext scope(b, xferOp.getLoc()); @@ -2426,8 +2426,9 @@ b.getI64IntegerAttr(32)); } - MemRefType compatibleMemRefType = getCastCompatibleMemRefType( - xferOp.getMemRefType(), alloc.getType().cast()); + MemRefType compatibleMemRefType = + getCastCompatibleMemRefType(xferOp.getShapedType().cast(), + alloc.getType().cast()); // Read case: full fill + partial copy -> unmasked vector.xfer_read. 
SmallVector returnTypes(1 + xferOp.getTransferRank(), @@ -2543,7 +2544,7 @@ extract.ids()[idCount++] * std_constant_index(extract.getResultType().getDimSize(pos)); } - Value newRead = vector_transfer_read(extract.getType(), read.memref(), + Value newRead = vector_transfer_read(extract.getType(), read.source(), indices, read.permutation_map(), read.padding(), read.maskedAttr()); Value dest = rewriter.create( @@ -2579,7 +2580,7 @@ insert.ids()[idCount++] * std_constant_index(insert.getSourceVectorType().getDimSize(pos)); } - vector_transfer_write(insert.vector(), write.memref(), indices, + vector_transfer_write(insert.vector(), write.source(), indices, write.permutation_map(), write.maskedAttr()); rewriter.eraseOp(write); return success(); Index: mlir/lib/Dialect/Vector/VectorUtils.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorUtils.cpp +++ mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -243,16 +243,16 @@ return ::makePermutationMap(indices, enclosingLoopToVectorDim); } -AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType, +AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType) { int64_t elementVectorRank = 0; VectorType elementVectorType = - memRefType.getElementType().dyn_cast(); + shapedType.getElementType().dyn_cast(); if (elementVectorType) elementVectorRank += elementVectorType.getRank(); return AffineMap::getMinorIdentityMap( - memRefType.getRank(), vectorType.getRank() - elementVectorRank, - memRefType.getContext()); + shapedType.getRank(), vectorType.getRank() - elementVectorRank, + shapedType.getContext()); } bool matcher::operatesOnSuperVectorsOf(Operation &op, @@ -314,12 +314,12 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) { - if (transferA.memref() != transferB.memref()) + if (transferA.source() != transferB.source()) return false; // For simplicity only look at transfer of same type. 
if (transferA.getVectorType() != transferB.getVectorType()) return false; - unsigned rankOffset = transferA.getLeadingMemRefRank(); + unsigned rankOffset = transferA.getLeadingShapedRank(); for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { auto indexA = transferA.indices()[i].getDefiningOp(); auto indexB = transferB.indices()[i].getDefiningOp(); Index: mlir/test/Dialect/Vector/invalid.mlir =================================================================== --- mlir/test/Dialect/Vector/invalid.mlir +++ mlir/test/Dialect/Vector/invalid.mlir @@ -269,7 +269,7 @@ %c3 = constant 3 : index %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<4x3xf32> - // expected-error@+1 {{ requires memref type}} + // expected-error@+1 {{ requires memref or ranked tensor type}} %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32> } @@ -297,7 +297,7 @@ func @test_vector.transfer_read(%arg0: memref) { %c3 = constant 3 : index %cst = constant 3.0 : f32 - // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}} + // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}} %0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0)->(d0)>} : memref, vector<128xf32> } @@ -343,7 +343,7 @@ %c3 = constant 3 : index %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<4x3xf32> - // expected-error@+1 {{requires memref vector element and vector result ranks to match}} + // expected-error@+1 {{requires source vector element and vector result ranks to match}} %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref>, vector<3xf32> } @@ -353,7 +353,7 @@ %c3 = constant 3 : index %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<6xf32> - // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}} + // 
expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}} %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref>, vector<3xf32> } @@ -392,7 +392,7 @@ %c3 = constant 3 : index %f0 = constant 0.0 : f32 %vf0 = splat %f0 : vector<4x3xf32> - // expected-error@+1 {{ requires memref type}} + // expected-error@+1 {{ requires memref or ranked tensor type}} vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32 } @@ -419,7 +419,7 @@ func @test_vector.transfer_write(%arg0: memref) { %c3 = constant 3 : index %cst = constant dense<3.0> : vector<128 x f32> - // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}} + // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}} vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0)->(d0)>} : vector<128xf32>, memref } Index: mlir/test/Dialect/Vector/ops.mlir =================================================================== --- mlir/test/Dialect/Vector/ops.mlir +++ mlir/test/Dialect/Vector/ops.mlir @@ -43,6 +43,54 @@ return } + +// CHECK-LABEL: func @vector_transfer_ops_tensor( +func @vector_transfer_ops_tensor(%arg0: tensor, + %arg1 : tensor>, + %arg2 : tensor>) -> + (tensor, tensor, tensor>, + tensor>, tensor>){ + // CHECK: %[[C3:.*]] = constant 3 : index + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + %f0 = constant 0.0 : f32 + %c0 = constant 0 : i32 + %vf0 = splat %f0 : vector<4x3xf32> + %v0 = splat %c0 : vector<4x3xi32> + + // + // CHECK: vector.transfer_read + %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor, vector<128xf32> + // CHECK: vector.transfer_read + %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : tensor, vector<3x7xf32> + // CHECK: vector.transfer_read + %2 = vector.transfer_read 
%arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor, vector<128xf32> + // CHECK: vector.transfer_read + %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : tensor, vector<128xf32> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor>, vector<1x1x4x3xf32> + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : tensor>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : tensor>, vector<1x1x4x3xf32> + %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : tensor>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor>, vector<5x24xi8> + %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor>, vector<5x24xi8> + + + // CHECK: vector.transfer_write + %7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor + // CHECK: vector.transfer_write + %8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor> + %9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor> + %10 = vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, tensor> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor> + %11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor> + + return %7, %8, %9, %10, %11 : + tensor, tensor, tensor>, + tensor>, tensor> +} + // CHECK-LABEL: @vector_broadcast func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: 
vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> { // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>