diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1133,26 +1133,28 @@ DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, - Arguments<(ins AnyShaped:$source, Variadic:$indices, - AffineMapAttr:$permutation_map, AnyType:$padding, - Optional>:$mask, - OptionalAttr:$in_bounds)>, - Results<(outs AnyVector:$vector)> { + Arguments<(ins AnyShaped:$source, + Variadic:$indices, + AffineMapAttr:$permutation_map, + AnyType:$padding, + Optional>:$mask, + OptionalAttr:$in_bounds)>, + Results<(outs AnyVectorOfAnyRank:$vector)> { let summary = "Reads a supervector from memory into an SSA vector value."; let description = [{ The `vector.transfer_read` op performs a read from a slice within a [MemRef](../LangRef.md#memref-type) or a Ranked - [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a - [vector](../LangRef.md#vector-type) of the same base elemental type. + [Tensor](../LangRef.md#tensor-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. A memref/tensor operand with vector element type, must have its vector element type match a suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>). The slice is further defined by a full-rank index within the MemRef/Tensor, - supplied as the operands `2 .. 1 + rank(memref/tensor)`. + supplied as the operands `[1 .. 1 + rank(memref/tensor))`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the @@ -1301,39 +1303,31 @@ }]; let builders = [ - // Builder that sets padding to zero. 
- OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, "AffineMap":$permutationMap, - CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map to 'getMinorIdentityMap'. - OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, "Value":$padding, - CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). - OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that does not set mask. - OpBuilder<(ins "Type":$vector, "Value":$source, - "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding, - "ArrayAttr":$inBounds)>, - // Builder that does not set mask. - OpBuilder<(ins "Type":$vector, "Value":$source, - "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding, - "ArrayAttr":$inBounds)> + /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "ArrayAttr":$inBoundsAttr)>, + /// 2. Builder that sets padding to zero and an empty mask (variant without attrs). + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "AffineMap":$permutationMap, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "Value":$padding, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 4. Builder that sets padding to zero and permutation map to + /// 'getMinorIdentityMap'. 
+ OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; - - let extraClassDeclaration = [{ - /// Temporary convenience builders to account for the fact that we do not - /// have 0-d vectors atm. These create a constant `vector<1xt>` and - /// insert/extract into it. - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). - static Value createScalarOp(OpBuilder &builder, Location loc, Value source, - ValueRange indices, - ArrayRef inBounds = ArrayRef{}); - }]; - let hasCanonicalizer = 1; let hasFolder = 1; } @@ -1345,11 +1339,12 @@ DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, - Arguments<(ins AnyVector:$vector, AnyShaped:$source, - Variadic:$indices, - AffineMapAttr:$permutation_map, - Optional>:$mask, - OptionalAttr:$in_bounds)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, + AnyShaped:$source, + Variadic:$indices, + AffineMapAttr:$permutation_map, + Optional>:$mask, + OptionalAttr:$in_bounds)>, Results<(outs Optional:$result)> { let summary = "The vector.transfer_write op writes a supervector to memory."; @@ -1367,7 +1362,7 @@ new tensor of the same type. The slice is further defined by a full-rank index within the MemRef/Tensor, - supplied as the operands `3 .. 2 + rank(memref/tensor)`. + supplied as the operands `[2 .. 2 + rank(memref/tensor))`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the @@ -1444,32 +1439,32 @@ }]; let builders = [ - // Builder that sets an empty mask. - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map to 'getMinorIdentityMap'. 
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - CArg<"ArrayRef", "{}">:$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>, + /// 1. Builder with type inference. + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "Value":$mask, + "ArrayAttr":$inBoundsAttr)>, + /// 2. Builder with type inference that sets an empty mask (variant with attrs). + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "ArrayAttr":$inBoundsAttr)>, + /// 3. Builder with type inference that sets an empty mask (variant without attrs). + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "AffineMap":$permutationMap, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 4. Builder with type inference that sets an empty mask and sets permutation + /// map to 'getMinorIdentityMap'. + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; - - let extraClassDeclaration = [{ - /// Temporary convenience builders to account for the fact that we do not - /// have 0-d vectors atm. These create a constant `vector<1xt>` and - /// insert/extract into it. - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). 
- static Operation *createScalarOp( - OpBuilder &builder, Location loc, Value value, - Value dest, ValueRange indices, - ArrayRef inBounds = ArrayRef{}); - }]; - let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -114,29 +114,6 @@ /*methodBody=*/"return $_op.permutation_map();" /*defaultImplementation=*/ >, - InterfaceMethod< - /*desc=*/[{ - Returns true if op involves a 0-d tensor/memref and a vector - of shape {1}. This is temporary until we have 0-d vectors. - // TODO: turn this into 0-d vectors + empty permutation_map. - }], - /*retTy=*/"bool", - /*methodName=*/"isZeroD", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - if (getShapedType().getRank() > 0) - return false; - if (getVectorType().getShape() != ArrayRef{1}) - return false; - AffineMap map = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, $_op->getContext())); - if ($_op.permutation_map() != map) - return false; - return true; - }] - >, InterfaceMethod< /*desc=*/[{ Returns true if the specified dimension is a broadcast. }], /*retTy=*/"bool", @@ -157,10 +134,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // 0-d transfers are not considered broadcasts but they need to be - // represented with a vector<1xt> until we have 0-d vectors. 
- if ($_op.isZeroD()) return false; - for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) { + for (unsigned i = 0, rank = getTransferRank(); i < rank; ++i) { if ($_op.isBroadcastDim(i)) return true; } diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -92,6 +92,10 @@ // Return true if the transfer op can be converted to a MMA matrix store. static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { + // TODO: support 0-d corner case. + if (writeOp.getTransferRank() == 0) + return false; + if (writeOp.mask() || writeOp.hasOutOfBoundsDim() || writeOp.getVectorType().getRank() != 2) return false; @@ -295,6 +299,11 @@ auto transferReadOp = op.vector().getDefiningOp(); if (!transferReadOp) return failure(); + + // TODO: support 0-d corner case. + if (transferReadOp.getTransferRank() == 0) + return failure(); + if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim()) return failure(); SmallVector perm; @@ -307,8 +316,8 @@ AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map()); rewriter.replaceOpWithNewOp( op, op.getType(), transferReadOp.source(), transferReadOp.indices(), - newMap, transferReadOp.padding(), transferReadOp.mask(), - transferReadOp.in_boundsAttr()); + AffineMapAttr::get(newMap), transferReadOp.padding(), + transferReadOp.mask(), transferReadOp.in_boundsAttr()); return success(); } }; @@ -335,6 +344,7 @@ static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap &valueMapping) { + assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); assert(transferReadSupportsMMAMatrixType(op)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- 
a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -64,6 +64,10 @@ LogicalResult matchAndRewrite(ConcreteOp xferOp, typename ConcreteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (xferOp.getTransferRank() == 0) + return failure(); + if (xferOp.getVectorType().getRank() > 1 || llvm::size(xferOp.indices()) == 0) return failure(); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -52,6 +52,8 @@ /// A return value of None indicates a broadcast. template static Optional unpackedDim(OpTy xferOp) { + // TODO: support 0-d corner case. + assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); auto map = xferOp.permutation_map(); if (auto expr = map.getResult(0).template dyn_cast()) { return expr.getPosition(); @@ -66,6 +68,8 @@ /// omitted. template static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { + // TODO: support 0-d corner case. + assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); auto map = xferOp.permutation_map(); return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), b.getContext()); @@ -1081,6 +1085,7 @@ SmallVector &memrefIndices) { auto indices = xferOp.indices(); auto map = xferOp.permutation_map(); + assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); memrefIndices.append(indices.begin(), indices.end()); assert(map.getNumResults() == 1 && @@ -1206,6 +1211,9 @@ LogicalResult matchAndRewrite(OpTy xferOp, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. 
+ if (xferOp.getTransferRank() == 0) + return failure(); auto map = xferOp.permutation_map(); auto memRefType = xferOp.getShapedType().template dyn_cast(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -101,8 +101,7 @@ return failure(); b.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), - writeOp.permutation_map(), - writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); + writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); state.mapBuffer(op->getResult(0), resultBuffer); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -115,8 +115,6 @@ /// ShapedType of `v`. static VectorType extractVectorTypeFromShapedValue(Value v) { auto st = v.getType().cast(); - if (st.getShape().empty()) - return VectorType(); return VectorType::get(st.getShape(), st.getElementType()); } @@ -179,21 +177,6 @@ return b.createOrFold(loc, targetVectorType, value); } -/// Build a vector.transfer_read from `source` at indices set to all `0`. -/// If source has rank zero, build a `vector<1xt> transfer_read + extract`. -/// Return the produced value. 
-static Value buildVectorRead(OpBuilder &b, Value source, Type readType, - AffineMap map) { - Location loc = source.getLoc(); - auto shapedType = source.getType().cast(); - SmallVector indices(shapedType.getRank(), - b.create(loc, 0)); - if (auto vectorType = readType.dyn_cast()) - return b.create(loc, vectorType, source, indices, - map); - return vector::TransferReadOp::createScalarOp(b, loc, source, indices); -} - /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This /// assumes that `reductionOp` has two operands and one of them is the reduction /// initial value. @@ -226,8 +209,11 @@ Operation *write; Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); - if (VectorType vectorType = - extractVectorTypeFromShapedValue(outputOperand->get())) { + ArrayRef shape = linalgOp.getShape(outputOperand); + auto vectorType = VectorType::get( + shape, getElementTypeOrSelf(outputOperand->get().getType())); + if (vectorType.getRank() > 0) { + // 0-d case is still special: do not invert the reindexing map. AffineMap map = reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)); SmallVector transposeShape = @@ -240,8 +226,11 @@ write = b.create(loc, value, outputOperand->get(), indices, map); } else { - write = vector::TransferWriteOp::createScalarOp( - b, loc, value, outputOperand->get(), ValueRange{}); + if (!value.getType().isa()) + value = b.create(loc, vectorType, value); + assert(value.getType() == vectorType && "incorrect type"); + write = b.create(loc, value, outputOperand->get(), + ValueRange{}); } LDBG("vectorized op: " << *write); if (!write->getResults().empty()) @@ -515,32 +504,42 @@ SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); // 3. Turn all BBArgs into vector.transfer_read / load. 
- SmallVector indexings; + Location loc = linalgOp.getLoc(); + Value zero = b.create(loc, 0); for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); if (linalgOp.isScalar(opOperand)) { bvm.map(bbarg, opOperand->get()); continue; } - // TODO: 0-d vectors. - Type readType; + VectorType readType; AffineMap map; - if (linalgOp.getShape(opOperand).empty()) { - readType = bbarg.getType(); + // TODO: can we keep this simplification? + // if (linalgOp.getShape(opOperand).empty()) { + // readType = VectorType::get({}, bbarg.getType()); + // } else { + if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { + map = inverseAndBroadcastProjectedPermuation( + linalgOp.getTiedIndexingMap(opOperand)); + readType = VectorType::get(commonVectorShape, + getElementTypeOrSelf(opOperand->get())); } else { - if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { - map = inverseAndBroadcastProjectedPermuation( - linalgOp.getTiedIndexingMap(opOperand)); - readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); - } else { - map = inversePermutation( - reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); - } + map = inversePermutation( + reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); + readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), + getElementTypeOrSelf(opOperand->get())); } - Value readValue = buildVectorRead(b, opOperand->get(), readType, map); + // } + + auto shape = linalgOp.getShape(opOperand); + SmallVector indices(shape.size(), zero); + Value readValue = b.create( + loc, readType, opOperand->get(), indices, map); + // Not all ops support 0-d vectors, extract the scalar for now. + // TODO: remove this. 
+ if (readValue.getType().cast().getRank() == 0) + readValue = b.create(loc, readValue); + LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); @@ -752,7 +751,7 @@ rewriter.create(padOp.getLoc(), 0)); auto read = rewriter.create( padOp.getLoc(), vecType, padOp.source(), readIndices, padValue, - readInBounds); + ArrayRef{readInBounds}); // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire // tensor, write directly to the FillOp's operand. @@ -765,7 +764,7 @@ auto writeIndices = ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad()); rewriter.replaceOpWithNewOp( - padOp, read, dest, writeIndices, writeInBounds); + padOp, read, dest, writeIndices, ArrayRef{writeInBounds}); return success(); } @@ -878,6 +877,10 @@ LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp, vector::TransferWriteOp xferOp) const override { + // TODO: support 0-d corner case. + if (xferOp.getTransferRank() == 0) + return failure(); + // Low padding must be static 0. if (!padOp.hasZeroLowPad()) return failure(); @@ -1072,7 +1075,8 @@ ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(vecRank, true); rewriter.replaceOpWithNewOp( - insertOp, read, insertOp.dest(), writeIndices, inBounds); + insertOp, read, insertOp.dest(), writeIndices, + ArrayRef{inBounds}); return success(); } @@ -1266,6 +1270,10 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { + // TODO: support mask. + if (xferOp.mask()) + return failure(); + // Transfer into `view`. Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && @@ -1328,7 +1336,9 @@ // conservatively. 
Value res = rewriter.create( xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(), - xferOp.permutation_map(), xferOp.padding(), ArrayAttr()); + xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(), + // in_bounds is explicitly reset + /*inBoundsAttr=*/ArrayAttr()); if (maybeFillOp) rewriter.eraseOp(maybeFillOp); @@ -1342,6 +1352,10 @@ /// when available. LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { + // TODO: support mask. + if (xferOp.mask()) + return failure(); + // Transfer into `viewOrAlloc`. Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && @@ -1380,7 +1394,9 @@ // conservatively. rewriter.create( xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(), - xferOp.permutation_map(), ArrayAttr()); + xferOp.permutation_mapAttr(), xferOp.mask(), + // in_bounds is explicitly reset + /*inBoundsAttr=*/ArrayAttr()); rewriter.eraseOp(copyOp); rewriter.eraseOp(xferOp); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -103,9 +103,9 @@ /// Given the permutation map of the original /// `vector.transfer_read`/`vector.transfer_write` operations compute the /// permutation map to use after the subview is folded with it. 
-static AffineMap getPermutationMap(MLIRContext *context, - memref::SubViewOp subViewOp, - AffineMap currPermutationMap) { +static AffineMapAttr getPermutationMapAttr(MLIRContext *context, + memref::SubViewOp subViewOp, + AffineMap currPermutationMap) { llvm::SmallDenseSet unusedDims = subViewOp.getDroppedDims(); SmallVector exprs; int64_t sourceRank = subViewOp.getSourceType().getRank(); @@ -115,7 +115,8 @@ exprs.push_back(getAffineDimExpr(dim, context)); } auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); - return currPermutationMap.compose(resultDimToSourceDimMap); + return AffineMapAttr::get( + currPermutationMap.compose(resultDimToSourceDimMap)); } //===----------------------------------------------------------------------===// @@ -163,13 +164,18 @@ template <> void LoadOpOfSubViewFolder::replaceOp( - vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, + vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { + // TODO: support 0-d corner case. + if (transferReadOp.getTransferRank() == 0) + return; rewriter.replaceOpWithNewOp( - loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, - getPermutationMap(rewriter.getContext(), subViewOp, - loadOp.permutation_map()), - loadOp.padding(), loadOp.in_boundsAttr()); + transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), + sourceIndices, + getPermutationMapAttr(rewriter.getContext(), subViewOp, + transferReadOp.permutation_map()), + transferReadOp.padding(), + /*mask=*/Value(), transferReadOp.in_boundsAttr()); } template <> @@ -184,11 +190,14 @@ void StoreOpOfSubViewFolder::replaceOp( vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { + // TODO: support 0-d corner case. 
+ if (transferWriteOp.getTransferRank() == 0) + return; rewriter.replaceOpWithNewOp( transferWriteOp, transferWriteOp.vector(), subViewOp.source(), sourceIndices, - getPermutationMap(rewriter.getContext(), subViewOp, - transferWriteOp.permutation_map()), + getPermutationMapAttr(rewriter.getContext(), subViewOp, + transferWriteOp.permutation_map()), transferWriteOp.in_boundsAttr()); } } // namespace diff --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp @@ -133,6 +133,10 @@ LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (read.getTransferRank() == 0) + return failure(); + if (read.mask()) return failure(); @@ -153,14 +157,15 @@ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, rewriter.getContext()); - ArrayAttr inBounds; + ArrayAttr inBoundsAttr; if (read.in_bounds()) - inBounds = rewriter.getArrayAttr( + inBoundsAttr = rewriter.getArrayAttr( read.in_boundsAttr().getValue().take_back(newType.getRank())); auto newRead = rewriter.create( - read.getLoc(), newType, read.source(), read.indices(), newMap, - read.padding(), inBounds); + read.getLoc(), newType, read.source(), read.indices(), + AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(), + inBoundsAttr); rewriter.replaceOpWithNewOp(read, oldType, newRead); return success(); @@ -176,6 +181,10 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. 
+ if (write.getTransferRank() == 0) + return failure(); + if (write.mask()) return failure(); @@ -196,15 +205,16 @@ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, rewriter.getContext()); - ArrayAttr inBounds; + ArrayAttr inBoundsAttr; if (write.in_bounds()) - inBounds = rewriter.getArrayAttr( + inBoundsAttr = rewriter.getArrayAttr( write.in_boundsAttr().getValue().take_back(newType.getRank())); auto newVector = rewriter.create( write.getLoc(), write.vector(), splatZero(dropDim)); rewriter.replaceOpWithNewOp( - write, newVector, write.source(), write.indices(), newMap, inBounds); + write, newVector, write.source(), write.indices(), + AffineMapAttr::get(newMap), inBoundsAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1613,8 +1613,8 @@ static_cast(destVectorType.getRank()))) return op.emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); - if (!srcVectorType && (positionAttr.size() != - static_cast(destVectorType.getRank()))) + if (!srcVectorType && + (positionAttr.size() != static_cast(destVectorType.getRank()))) return op.emitOpError( "expected position attribute rank to match the dest vector rank"); for (auto en : llvm::enumerate(positionAttr)) { @@ -2314,6 +2314,59 @@ // TransferReadOp //===----------------------------------------------------------------------===// +/// 1. Builder that sets padding to zero and an empty mask (variant with attrs). 
+void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, AffineMapAttr permutationMapAttr, + /*optional*/ ArrayAttr inBoundsAttr) { + Type elemType = source.getType().cast().getElementType(); + Value padding = builder.create( + result.location, elemType, builder.getZeroAttr(elemType)); + build(builder, result, vectorType, source, indices, permutationMapAttr, + padding, /*mask=*/Value(), inBoundsAttr); +} + +/// 2. Builder that sets padding to zero and an empty mask (variant without attrs). +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, AffineMap permutationMap, + Optional> inBounds) { + auto permutationMapAttr = AffineMapAttr::get(permutationMap); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vectorType, source, indices, permutationMapAttr, + inBoundsAttr); +} + +/// 3. Builder that sets permutation map to 'getMinorIdentityMap'. +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, Value padding, + Optional> inBounds) { + AffineMap permutationMap = getTransferMinorIdentityMap( + source.getType().cast(), vectorType); + auto permutationMapAttr = AffineMapAttr::get(permutationMap); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vectorType, source, indices, permutationMapAttr, + padding, + /*mask=*/Value(), inBoundsAttr); +} + +/// 4. Builder that sets padding to zero and permutation map to + /// 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, + Optional> inBounds) { + Type elemType = source.getType().cast().getElementType(); + Value padding = builder.create( + result.location, elemType, builder.getZeroAttr(elemType)); + build(builder, result, vectorType, source, indices, padding, inBounds); +} + template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { @@ -2347,10 +2400,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, AffineMap permutationMap, ArrayAttr inBounds) { - if (shapedType.getRank() == 0 && !op.isZeroD()) - return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> " - "(0) permutation_map"); - if (op->hasAttr("masked")) { return op->emitOpError("masked attribute has been removed. " "Use in_bounds instead."); @@ -2359,6 +2408,7 @@ if (!shapedType.isa()) return op->emitOpError( "requires source to be a memref or ranked tensor type"); + auto elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); if (auto vectorElementType = elementType.dyn_cast()) { @@ -2389,9 +2439,10 @@ return op->emitOpError("does not support masks with vector element type"); } else { // Memref or tensor has scalar element type. + unsigned minorSize = + vectorType.getRank() == 0 ? 1 : vectorType.getShape().back(); unsigned resultVecSize = - dataLayout.getTypeSizeInBits(vectorType.getElementType()) * - vectorType.getShape().back(); + dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " @@ -2412,8 +2463,8 @@ if (permutationMap.getNumSymbols() != 0) return op->emitOpError("requires permutation_map without symbols"); - // TODO: implement 0-d vector corner cases. 
- if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank()) + + if (permutationMap.getNumInputs() != shapedType.getRank()) return op->emitOpError("requires a permutation_map with input dims of the " "same rank as the source type"); @@ -2421,7 +2472,8 @@ if (permutationMap.getNumResults() != static_cast(inBounds.size())) return op->emitOpError("expects the optional in_bounds attr of same rank " "as permutation_map results: ") - << AffineMapAttr::get(permutationMap); + << AffineMapAttr::get(permutationMap) + << " vs inBounds of size: " << inBounds.size(); for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) if (permutationMap.getResult(i).isa() && !inBounds.getValue()[i].cast().getValue()) @@ -2431,77 +2483,6 @@ return success(); } -/// Builder that sets padding to zero. -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, AffineMap permutationMap, - ArrayRef inBounds) { - Type elemType = source.getType().cast().getElementType(); - Value padding = builder.create( - result.location, elemType, builder.getZeroAttr(elemType)); - if (inBounds.empty()) - return build(builder, result, vectorType, source, indices, permutationMap, - padding, ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vectorType, source, indices, permutationMap, padding, - inBoundsArrayAttr); -} - -/// Builder that sets permutation map to 'getMinorIdentityMap'. 
-void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, Value padding, - ArrayRef inBounds) { - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - if (inBounds.empty()) - return build(builder, result, vectorType, source, indices, permMap, padding, - ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vectorType, source, indices, permMap, padding, - inBoundsArrayAttr); -} - -/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap' -/// (resp. zero). -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, ArrayRef inBounds) { - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - build(builder, result, vectorType, source, indices, permMap, inBounds); -} - -/// Builder that does not provide a mask. -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - Type vectorType, Value source, ValueRange indices, - AffineMap permutationMap, Value padding, - ArrayAttr inBounds) { - build(builder, result, vectorType, source, indices, permutationMap, padding, - /*mask=*/Value(), inBounds); -} - -/// Builder that does not provide a mask. 
-void TransferReadOp::build(OpBuilder &builder, OperationState &result, - Type vectorType, Value source, ValueRange indices, - AffineMapAttr permutationMap, Value padding, - ArrayAttr inBounds) { - build(builder, result, vectorType, source, indices, permutationMap, padding, - /*mask=*/Value(), inBounds); -} - -Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc, - Value source, ValueRange indices, - ArrayRef inBounds) { - Type elemType = source.getType().cast().getElementType(); - auto vectorType = VectorType::get(ArrayRef{1}, elemType); - AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, loc.getContext())); - Value read = builder.create(loc, vectorType, source, - indices, map, inBounds); - return builder.create(loc, read, ArrayRef{0}); -} - static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); @@ -2563,6 +2544,7 @@ Attribute mapAttr = result.attributes.get(permutationAttrName); if (!mapAttr) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); + // Update `mapAttr` that is used later to determine mask type. mapAttr = AffineMapAttr::get(permMap); result.attributes.set(permutationAttrName, mapAttr); } @@ -2677,8 +2659,9 @@ template static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { - // TODO: Be less conservative once we have 0-d vectors. - if (op.isZeroD()) + // TODO: support 0-d corner case. + // TODO: Be less conservative. + if (op.getTransferRank() == 0) return failure(); AffineMap permutationMap = op.permutation_map(); bool changed = false; @@ -2783,6 +2766,9 @@ LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. 
+ if (xferOp.getTransferRank() == 0) + return failure(); if (xferOp.hasOutOfBoundsDim()) return failure(); if (!xferOp.permutation_map().isIdentity()) @@ -2814,9 +2800,9 @@ offset))); } SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp(xferOp, xferOp.getVectorType(), - extractOp.source(), newIndices, - xferOp.padding(), inBounds); + rewriter.replaceOpWithNewOp( + xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, + xferOp.padding(), ArrayRef{inBounds}); return success(); } @@ -2832,69 +2818,49 @@ // TransferWriteOp //===----------------------------------------------------------------------===// +/// 1. Builder with type inference. void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value dest, ValueRange indices, - AffineMap permutationMap, ArrayRef inBounds) { - if (inBounds.empty()) - return build(builder, result, vector, dest, indices, permutationMap, - /*mask=*/Value(), ArrayAttr()); - build(builder, result, vector, dest, indices, permutationMap, - /*mask=*/Value(), builder.getBoolArrayAttr(inBounds)); -} - -/// Builder that sets permutation map to 'getMinorIdentityMap'. -void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - ArrayRef inBounds) { - auto vectorType = vector.getType().cast(); - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - if (inBounds.empty()) - return build(builder, result, vector, source, indices, permMap, - ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr); + AffineMapAttr permutationMapAttr, + /*optional*/ Value mask, + /*optional*/ ArrayAttr inBoundsAttr) { + Type resultType = dest.getType().dyn_cast(); + build(builder, result, resultType, vector, dest, indices, permutationMapAttr, + mask, inBoundsAttr); } +/// 2. 
Builder with type inference that sets an empty mask (variant with attrs). void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMapAttr permutationMap, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - /*mask=*/Value(), inBounds); + Value vector, Value dest, ValueRange indices, + AffineMapAttr permutationMapAttr, + /*optional*/ ArrayAttr inBoundsAttr) { + build(builder, result, vector, dest, indices, permutationMapAttr, + /*mask=*/Value(), inBoundsAttr); } +/// 3. Builder with type inference that sets an empty mask (variant without +/// attrs) void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, + Value vector, Value dest, ValueRange indices, AffineMap permutationMap, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - /*mask=*/Value(), inBounds); + Optional> inBounds) { + auto permutationMapAttr = AffineMapAttr::get(permutationMap); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vector, dest, indices, permutationMapAttr, + /*mask=*/Value(), inBoundsAttr); } +/// 4. Builder with type inference that sets an empty mask and sets permutation +/// map to 'getMinorIdentityMap'. 
void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMap permutationMap, /*optional*/ Value mask, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - mask, inBounds); -} - -Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc, - Value value, Value dest, - ValueRange indices, - ArrayRef inBounds) { - Value vectorOfAScalar = value; - if (!value.getType().isa()) - vectorOfAScalar = builder.create( - loc, VectorType::get({1}, value.getType()), value); - AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, loc.getContext())); - return builder.create(loc, vectorOfAScalar, dest, - indices, map, inBounds); + Value vector, Value dest, ValueRange indices, + Optional> inBounds) { + auto vectorType = vector.getType().cast(); + AffineMap permutationMap = getTransferMinorIdentityMap( + dest.getType().cast(), vectorType); + build(builder, result, vector, dest, indices, permutationMap, inBounds); } static ParseResult parseTransferWriteOp(OpAsmParser &parser, @@ -3003,6 +2969,9 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, ArrayRef, SmallVectorImpl &results) { + // TODO: support 0-d corner case. + if (write.getTransferRank() == 0) + return failure(); auto rankedTensorType = write.source().getType().dyn_cast(); // If not operating on tensors, bail. if (!rankedTensorType) @@ -3011,6 +2980,9 @@ auto read = write.vector().getDefiningOp(); if (!read) return failure(); + // TODO: support 0-d corner case. + if (read.getTransferRank() == 0) + return failure(); // For now, only accept minor identity. Future: composition is minor identity. 
if (!read.permutation_map().isMinorIdentity() || !write.permutation_map().isMinorIdentity()) @@ -3179,9 +3151,14 @@ PatternRewriter &rewriter) const override { if (!insertOp.hasUnitStride()) return failure(); + auto xferOp = insertOp.source().getDefiningOp(); if (!xferOp) return failure(); + // TODO: support 0-d corner case. + if (xferOp.getTransferRank() == 0) + return failure(); + if (xferOp.hasOutOfBoundsDim()) return failure(); if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) @@ -3200,8 +3177,9 @@ SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp( - insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds); + rewriter.replaceOpWithNewOp(insertOp, xferOp.vector(), + insertOp.dest(), indices, + ArrayRef{inBounds}); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp @@ -31,6 +31,7 @@ attr.getValue()[pos].cast().getValue()); return builder.getBoolArrayAttr(newInBoundsValues); } + /// Lower transfer_read op with permutation into a transfer_read with a /// permutation map composed of leading zeros followed by a minor identiy + /// vector.transpose op. @@ -56,6 +57,10 @@ LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (op.getTransferRank() == 0) + return failure(); + SmallVector permutation; AffineMap map = op.permutation_map(); if (map.getNumResults() == 0) @@ -99,7 +104,7 @@ } // Transpose in_bounds attribute. - ArrayAttr newInBounds = + ArrayAttr newInBoundsAttr = op.in_bounds() ? 
transposeInBoundsAttr( rewriter, op.in_bounds().getValue(), permutation) : ArrayAttr(); @@ -108,8 +113,8 @@ VectorType newReadType = VectorType::get(newVectorShape, op.getVectorType().getElementType()); Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), newMask, newInBounds); + op.getLoc(), newReadType, op.source(), op.indices(), + AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr); // Transpose result of transfer_read. SmallVector transposePerm(permutation.begin(), permutation.end()); @@ -141,7 +146,8 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp op, PatternRewriter &rewriter) const override { - if (op.isZeroD()) + // TODO: support 0-d corner case. + if (op.getTransferRank() == 0) return failure(); SmallVector permutation; @@ -168,7 +174,7 @@ : Value(); // Transpose in_bounds attribute. - ArrayAttr newInBounds = + ArrayAttr newInBoundsAttr = op.in_bounds() ? transposeInBoundsAttr( rewriter, op.in_bounds().getValue(), permutation) : ArrayAttr(); @@ -179,8 +185,8 @@ auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); rewriter.replaceOpWithNewOp( - op, Type(), newVec, op.source(), op.indices(), newMap, newMask, - newInBounds); + op, Type(), newVec, op.source(), op.indices(), + AffineMapAttr::get(newMap), newMask, newInBoundsAttr); return success(); } @@ -199,6 +205,10 @@ LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (op.getTransferRank() == 0) + return failure(); + AffineMap map = op.permutation_map(); unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { @@ -245,14 +255,14 @@ return failure(); VectorType newReadType = VectorType::get(newShape, originalVecType.getElementType()); - ArrayAttr newInBounds = + ArrayAttr newInBoundsAttr = op.in_bounds() ? 
rewriter.getArrayAttr( op.in_boundsAttr().getValue().take_back(reducedShapeRank)) : ArrayAttr(); Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), op.mask(), newInBounds); + op.getLoc(), newReadType, op.source(), op.indices(), + AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr); rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -229,7 +229,9 @@ options(options) {} LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { - + // TODO: support 0-d corner case. + if (readOp.getTransferRank() == 0) + return failure(); if (readOp.mask()) return failure(); auto targetShape = getTargetShape(options, readOp); @@ -254,9 +256,9 @@ sliceTransferIndices(i, originalSize, *targetShape, originalIndices, readOp.permutation_map(), loc, rewriter); auto slicedRead = rewriter.create( - loc, targetType, readOp.source(), indices, readOp.permutation_map(), - readOp.padding(), - readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr()); + loc, targetType, readOp.source(), indices, + readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(), + readOp.in_boundsAttr()); SmallVector elementOffsets = getVectorOffset(originalSize, *targetShape, i); @@ -279,6 +281,10 @@ options(options) {} LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (writeOp.getTransferRank() == 0) + return failure(); + if (writeOp.mask()) return failure(); auto targetShape = getTargetShape(options, writeOp); @@ -305,8 +311,7 @@ writeOp.permutation_map(), loc, rewriter); Operation *slicedWrite = rewriter.create( loc, slicedVector, resultTensor ? 
resultTensor : writeOp.source(), - indices, writeOp.permutation_map(), - writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); + indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); // For the tensor case update the destination for the next transfer write. if (!slicedWrite->getResults().empty()) resultTensor = slicedWrite->getResult(0); @@ -2057,6 +2062,10 @@ /// rank-reducing subviews. static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) { + // TODO: support 0-d corner case. + if (xferOp.getTransferRank() == 0) + return failure(); + // TODO: expand support to these 2 cases. if (!xferOp.permutation_map().isMinorIdentity()) return failure(); @@ -2682,6 +2691,10 @@ : OpRewritePattern(context) {} LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (read.getTransferRank() == 0) + return failure(); + if (!read.getResult().hasOneUse()) return failure(); auto extract = @@ -2711,8 +2724,8 @@ {indices[indexPos], extract.ids()[idCount++]}); } Value newRead = lb.create( - extract.getType(), read.source(), indices, read.permutation_map(), - read.padding(), read.in_boundsAttr()); + extract.getType(), read.source(), indices, read.permutation_mapAttr(), + read.padding(), read.mask(), read.in_boundsAttr()); Value dest = lb.create( read.getType(), rewriter.getZeroAttr(read.getType())); newRead = lb.create(newRead, dest, extract.ids()); @@ -2727,6 +2740,10 @@ : OpRewritePattern(context) {} LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. 
+ if (write.getTransferRank() == 0) + return failure(); + auto insert = write.vector().getDefiningOp(); if (!insert) return failure(); @@ -2754,8 +2771,8 @@ {indices[indexPos], insert.ids()[idCount++]}); } rewriter.create( - loc, insert.vector(), write.source(), indices, write.permutation_map(), - write.in_boundsAttr()); + loc, insert.vector(), write.source(), indices, + write.permutation_mapAttr(), write.in_boundsAttr()); rewriter.eraseOp(write); return success(); } @@ -2780,15 +2797,19 @@ PatternRewriter &rewriter) const override { if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) return failure(); + SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. + // We let the 0-d corner case pass-through as it is supported. if (!read.permutation_map().isMinorIdentityWithBroadcasting( &broadcastedDims)) return failure(); + auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) return failure(); @@ -2808,6 +2829,7 @@ auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) return failure(); + // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != read.getVectorType().getElementType()) @@ -2845,7 +2867,14 @@ llvm::Optional maxTransferRank; }; -/// Replace a scalar vector.load with a memref.load. +/// Replace a 0-d vector.load with a memref.load + vector.broadcast. +// TODO: we shouldn't cross the vector/scalar domains just for this +// but atm we lack the infra to avoid it. 
Possible solutions include: +// - go directly to LLVM + bitcast +// - introduce a bitcast op and likely a new pointer dialect +// - let memref.load/store additionally support the 0-d vector case +// There are still deeper data layout issues lingering even in this +// trivial case (for architectures for which this matters). struct VectorLoadToMemrefLoadLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2857,13 +2886,13 @@ return failure(); auto memrefLoad = rewriter.create( loadOp.getLoc(), loadOp.base(), loadOp.indices()); - rewriter.replaceOpWithNewOp( - loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad); + rewriter.replaceOpWithNewOp(loadOp, vecType, + memrefLoad); return success(); } }; -/// Replace a scalar vector.store with a memref.store. +/// Replace a 0-d vector.store with a vector.extractelement + memref.store. struct VectorStoreToMemrefStoreLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2873,9 +2902,17 @@ auto vecType = storeOp.getVectorType(); if (vecType.getNumElements() != 1) return failure(); - SmallVector indices(vecType.getRank(), 0); - Value extracted = rewriter.create( - storeOp.getLoc(), storeOp.valueToStore(), indices); + Value extracted; + if (vecType.getRank() == 0) { + // TODO: Unifiy once ExtractOp supports 0-d vectors. + extracted = rewriter.create( + storeOp.getLoc(), storeOp.valueToStore()); + } else { + SmallVector indices(vecType.getRank(), 0); + extracted = rewriter.create( + storeOp.getLoc(), storeOp.valueToStore(), indices); + } + rewriter.replaceOpWithNewOp( storeOp, extracted, storeOp.base(), storeOp.indices()); return success(); @@ -2901,25 +2938,32 @@ PatternRewriter &rewriter) const override { if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) return failure(); + // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. 
- if (!write.isZeroD() && !write.permutation_map().isMinorIdentity()) + if ( // pass-through for the 0-d corner case. + !write.permutation_map().isMinorIdentity()) return failure(); + auto memRefType = write.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) return failure(); + // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != write.getVectorType()) return failure(); + // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != write.getVectorType().getElementType()) return failure(); + // Out-of-bounds dims are handled by MaterializeTransferMask. if (write.hasOutOfBoundsDim()) return failure(); @@ -3319,6 +3363,14 @@ LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. + if (readOp.getTransferRank() == 0) + return failure(); + + // TODO: support mask. + if (readOp.mask()) + return failure(); + auto srcType = readOp.source().getType().dyn_cast(); if (!srcType || !srcType.hasStaticShape()) return failure(); @@ -3375,7 +3427,7 @@ SmallVector offsets(srcType.getRank(), 0); SmallVector strides(srcType.getRank(), 1); - ArrayAttr inBounds = + ArrayAttr inBoundsAttr = readOp.in_bounds() ? rewriter.getArrayAttr( readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) @@ -3387,8 +3439,10 @@ rankedReducedView.getType().cast(), resultTargetVecType); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, - readOp.indices().drop_back(dimsToDrop), permMap, readOp.padding(), - inBounds); + readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), + readOp.padding(), + // TODO: support mask. 
+ /*mask=*/Value(), inBoundsAttr); rewriter.replaceOpWithNewOp(readOp, targetType, result); return success(); diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp --- a/mlir/lib/Interfaces/VectorInterfaces.cpp +++ b/mlir/lib/Interfaces/VectorInterfaces.cpp @@ -20,7 +20,7 @@ shape.push_back(vecType.getDimSize(i)); } } - return shape.empty() ? VectorType() : VectorType::get(shape, i1Type); + return VectorType::get(shape, i1Type); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -2,25 +2,20 @@ // RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL // CHECK-LABEL: func @vector_transfer_ops_0d( -// CHECK-SAME: %[[MEM:.*]]: memref) { func @vector_transfer_ops_0d(%M: memref) { - %f0 = arith.constant 0.0 : f32 - -// CHECK: %[[V0:.*]] = arith.constant dense<0{{.*}}> : vector<1xf32> -// CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) { -// CHECK: %[[S:.*]] = memref.load %[[MEM]][] : memref -// CHECK: %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32> -// CHECK: scf.yield %[[R_ITER]] : vector<1xf32> - %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> - -// CHECK: scf.for %[[J:.*]] = %{{.*}} -// CHECK: %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32> -// CHECK: memref.store %[[SS]], %[[MEM]][] : memref - vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, memref - - return + %f0 = arith.constant 0.0 : f32 + + // 0-d transfers are left untouched by vector-to-scf. 
+ // They are independently lowered to the proper memref.load/store. + // CHECK: vector.transfer_read {{.*}}: memref, vector + %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->()>} : + memref, vector + + // CHECK: vector.transfer_write {{.*}}: vector, memref + vector.transfer_write %0, %M[] {permutation_map = affine_map<()->()>} : + vector, memref + + return } // ----- diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -200,8 +200,8 @@ // CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) - // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> - // CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref + // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector, memref linalg.fill(%arg0, %A) : f32, memref return } @@ -221,10 +221,10 @@ // CHECK-LABEL: func @test_vectorize_copy_scalar func @test_vectorize_copy_scalar(%A : memref, %B : memref) { // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector<1xf32> - // CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32> - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> - // CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector + // CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector, memref linalg.copy(%A, %B) : memref, memref return } @@ -1005,7 +1005,7 @@ // CHECK-LABEL: func @reduce_1d( // 
CHECK-SAME: %[[A:.*]]: tensor<32xf32> func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { - // CHECK-DAG: %[[F0_v1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + // CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 @@ -1013,17 +1013,18 @@ // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor %0 = linalg.init_tensor [] : tensor - // CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][] - // CHECK-SAME: : vector<1xf32>, tensor + // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] + // CHECK-SAME: : vector, tensor %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[r]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[F0]] : f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32> + // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32 + // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] - // CHECK-SAME: : vector<1xf32>, tensor + // CHECK-SAME: : vector, tensor %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1427,15 +1427,3 @@ %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32> } -// ----- - -func @vector_transfer_ops_0d(%arg0: tensor) - -> tensor { - %f0 = arith.constant 0.0 : f32 - // expected-error@+1 {{0-d transfer requires 
vector<1xt> shape and () -> (0) permutation_map}} - %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} : - tensor, vector<1xf32> - %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, tensor - return %1: tensor -} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -4,17 +4,33 @@ func @vector_transfer_ops_0d(%arg0: tensor, %arg1: memref) -> tensor { %f0 = arith.constant 0.0 : f32 - %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} : - tensor, vector<1xf32> - %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, tensor - %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> - vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, memref + %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->()>} : + tensor, vector + %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->()>} : + vector, tensor + %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->()>} : + memref, vector + vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->()>} : + vector, memref return %1: tensor } +// CHECK-LABEL: func @vector_transfer_ops_0d_from_higher_d( +func @vector_transfer_ops_0d_from_higher_d(%arg0: tensor, %arg1: memref) + -> tensor { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0], %f0 {permutation_map = affine_map<(d0)->()>} : + tensor, vector + %1 = vector.transfer_write %0, %arg0[%c0] {permutation_map = affine_map<(d0)->()>} : + vector, tensor + %2 = vector.transfer_read %arg1[%c0, %c0], %f0 {permutation_map = affine_map<(d0, d1)->()>} : + memref, vector + vector.transfer_write %2, %arg1[%c0, %c0] 
{permutation_map = affine_map<(d0, d1)->()>} : + vector, memref + return %1: tensor +} + // CHECK-LABEL: func @vector_transfer_ops( func @vector_transfer_ops(%arg0: memref, %arg1 : memref>, diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -6,13 +6,13 @@ func @vector_transfer_ops_0d_memref(%M: memref, %v: vector<1x1x1xf32>) { %f0 = arith.constant 0.0 : f32 -// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref - %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> +// CHECK-NEXT: %[[s:.*]] = memref.load %[[MEM]][] : memref +// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[s]] : f32 to vector + %0 = vector.transfer_read %M[], %f0 : memref, vector -// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref - vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, memref +// CHECK-NEXT: %[[ss:.*]] = vector.extractelement %[[V]][] : vector +// CHECK-NEXT: memref.store %[[ss]], %[[MEM]][] : memref + vector.transfer_write %0, %M[] : vector, memref // CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32> // CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref