diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -302,7 +302,7 @@ Results<(outs AnyType:$dest)> { let summary = "Multi-dimensional reduction operation"; let description = [{ - Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) + Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) using the given operation (add/mul/min/max for int/fp and and/or/xor for int only). @@ -380,7 +380,7 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyType:$source)>, - Results<(outs AnyVector:$vector)> { + Results<(outs AnyVectorOfAnyRank:$vector)> { let summary = "broadcast operation"; let description = [{ Broadcasts the scalar or k-D vector value in the source operand @@ -1133,32 +1133,35 @@ DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, - Arguments<(ins AnyShaped:$source, Variadic:$indices, - AffineMapAttr:$permutation_map, AnyType:$padding, - Optional>:$mask, - OptionalAttr:$in_bounds)>, - Results<(outs AnyVector:$vector)> { + Arguments<(ins AnyShaped:$source, + Variadic:$indices, + OptionalAttr:$permutation_map, + AnyType:$padding, + Optional>:$mask, + OptionalAttr:$in_bounds)>, + Results<(outs AnyVectorOfAnyRank:$vector)> { let summary = "Reads a supervector from memory into an SSA vector value."; let description = [{ The `vector.transfer_read` op performs a read from a slice within a [MemRef](../LangRef.md#memref-type) or a Ranked - [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a - [vector](../LangRef.md#vector-type) of the same base elemental type. + [Tensor](../LangRef.md#tensor-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. 
A memref/tensor operand with vector element type, must have its vector element type match a suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>). The slice is further defined by a full-rank index within the MemRef/Tensor, - supplied as the operands `2 .. 1 + rank(memref/tensor)`. + supplied as the operands `[1 .. 1 + rank(memref/tensor))`. - The permutation_map [attribute](../LangRef.md#attributes) is an + The permutation_map [attribute](../LangRef.md#attributes) is an optional [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The permutation map may be implicit and omitted from parsing and printing if it is the canonical minor identity map - (i.e. if it does not permute or broadcast any dimension). + (i.e. if it does not permute or broadcast any dimension). The permutation_map + attribute must be present in the op, unless the op involves a 0-d vector. The size of the slice is specified by the size of the vector, given as the return type. @@ -1184,6 +1187,9 @@ `%A[%expr1, %expr2, %expr3, %expr4]` in the example below, is expected to be in-bounds and as indices are increasing, accesses may run out-of-bounds. + In the case of a transfer that involves a 0-d vector, the permutation_map + and in_bounds attributes must all be omitted. + This operation is called 'read' by opposition to 'load' because the super-vector granularity is generally not representable with a single hardware register. A `vector.transfer_read` is thus a mid-level abstraction @@ -1301,37 +1307,37 @@ }]; let builders = [ - // Builder that sets padding to zero. - OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, "AffineMap":$permutationMap, - CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map to 'getMinorIdentityMap'. 
- OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, "Value":$padding, - CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). - OpBuilder<(ins "VectorType":$vector, "Value":$source, - "ValueRange":$indices, CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that does not set mask. - OpBuilder<(ins "Type":$vector, "Value":$source, - "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding, - "ArrayAttr":$inBounds)>, - // Builder that does not set mask. - OpBuilder<(ins "Type":$vector, "Value":$source, - "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding, - "ArrayAttr":$inBounds)> + /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "ArrayAttr":$inBoundsAttr)>, + /// 2. Builder that sets padding to zero and an empty mask (variant without attrs). + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "Optional":$permutationMap, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + "Value":$padding, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 4. Builder that sets padding to zero and permutation map to + /// 'getMinorIdentityMap'. + OpBuilder<(ins "VectorType":$vectorType, + "Value":$source, + "ValueRange":$indices, + CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; let extraClassDeclaration = [{ - /// Temporary convenience builders to account for the fact that we do not - /// have 0-d vectors atm. These create a constant `vector<1xt>` and - /// insert/extract into it. - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). 
- static Value createScalarOp(OpBuilder &builder, Location loc, Value source, - ValueRange indices, - ArrayRef inBounds = ArrayRef{}); + AffineMap map() { + return permutation_map() ? + permutation_map().getValue() : AffineMap(); + } }]; let hasCanonicalizer = 1; @@ -1345,11 +1351,12 @@ DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, - Arguments<(ins AnyVector:$vector, AnyShaped:$source, - Variadic:$indices, - AffineMapAttr:$permutation_map, - Optional>:$mask, - OptionalAttr:$in_bounds)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, + AnyShaped:$source, + Variadic:$indices, + OptionalAttr:$permutation_map, + Optional>:$mask, + OptionalAttr:$in_bounds)>, Results<(outs Optional:$result)> { let summary = "The vector.transfer_write op writes a supervector to memory."; @@ -1367,14 +1374,15 @@ new tensor of the same type. The slice is further defined by a full-rank index within the MemRef/Tensor, - supplied as the operands `3 .. 2 + rank(memref/tensor)`. + supplied as the operands `[2 .. 2 + rank(memref/tensor))`. - The permutation_map [attribute](../LangRef.md#attributes) is an + The permutation_map [attribute](../LangRef.md#attributes) is an optional [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The permutation map may be implicit and omitted from parsing and printing if it is the canonical minor identity map (i.e. if it does not permute any dimension). In contrast to `transfer_read`, - write ops cannot have broadcast dimensions. + write ops cannot have broadcast dimensions. The permutation_map attribute + must be present in the op, unless the op involves a 0-d vector. The size of the slice is specified by the size of the vector. @@ -1402,6 +1410,9 @@ `%A[%expr1, %expr2, %expr3, %expr4]` in the example below, is expected to be in-bounds and as indices are increasing, accesses may run out-of-bounds. 
+ In the case of a transfer that involves a 0-d vector, the permutation_map, + in_bounds and mask attributes must all be omitted. + This operation is called 'write' by opposition to 'store' because the super-vector granularity is generally not representable with a single hardware register. A `vector.transfer_write` is thus a @@ -1444,30 +1455,38 @@ }]; let builders = [ - // Builder that sets an empty mask. - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, CArg<"ArrayRef", "{}">:$inBounds)>, - // Builder that sets permutation map to 'getMinorIdentityMap'. - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - CArg<"ArrayRef", "{}">:$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>, - OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices, - "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>, + /// 1. Builder with type inference. + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "Value":$mask, + "ArrayAttr":$inBoundsAttr)>, + /// 2. Builder with type inference that sets an empty mask (variant with attrs). + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "AffineMapAttr":$permutationMapAttr, + "ArrayAttr":$inBoundsAttr)>, + /// 3. Builder with type inference that sets an empty mask (variant without attrs). + OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + "Optional":$permutationMap, + CArg<"Optional>", "::llvm::None">:$inBounds)>, + /// 4. Builder with type inference that sets an empty mask and sets permutation + /// map to 'getMinorIdentityMap'. 
+ OpBuilder<(ins "Value":$vector, + "Value":$dest, + "ValueRange":$indices, + CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; let extraClassDeclaration = [{ - /// Temporary convenience builders to account for the fact that we do not - /// have 0-d vectors atm. These create a constant `vector<1xt>` and - /// insert/extract into it. - // Builder that sets permutation map (resp. padding) to - // 'getMinorIdentityMap' (resp. zero). - static Operation *createScalarOp( - OpBuilder &builder, Location loc, Value value, - Value dest, ValueRange indices, - ArrayRef inBounds = ArrayRef{}); + AffineMap map() { + return permutation_map() ? + permutation_map().getValue() : AffineMap(); + } }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -76,10 +76,11 @@ /*args=*/(ins "unsigned":$dim), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.isBroadcastDim(dim) - || ($_op.in_bounds() - && $_op.in_bounds()->template cast<::mlir::ArrayAttr>()[dim] - .template cast<::mlir::BoolAttr>().getValue()); + return !$_op.map() || + $_op.isBroadcastDim(dim) || + ($_op.in_bounds() && + $_op.in_bounds()->template cast<::mlir::ArrayAttr>()[dim] + .template cast<::mlir::BoolAttr>().getValue()); }] >, InterfaceMethod< @@ -109,34 +110,11 @@ InterfaceMethod< /*desc=*/"Return the permutation map.", /*retTy=*/"::mlir::AffineMap", - /*methodName=*/"permutation_map", + /*methodName=*/"map", /*args=*/(ins), - /*methodBody=*/"return $_op.permutation_map();" + /*methodBody=*/"return $_op.map();" /*defaultImplementation=*/ >, - InterfaceMethod< - /*desc=*/[{ - Returns true if op involves a 0-d tensor/memref and a vector - of shape {1}. This is temporary until we have 0-d vectors. - // TODO: turn this into 0-d vectors + empty permutation_map. 
- }], - /*retTy=*/"bool", - /*methodName=*/"isZeroD", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - if (getShapedType().getRank() > 0) - return false; - if (getVectorType().getShape() != ArrayRef{1}) - return false; - AffineMap map = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, $_op->getContext())); - if ($_op.permutation_map() != map) - return false; - return true; - }] - >, InterfaceMethod< /*desc=*/[{ Returns true if the specified dimension is a broadcast. }], /*retTy=*/"bool", @@ -144,7 +122,8 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto expr = $_op.permutation_map().getResult(idx); + if (!$_op.map()) return false; + auto expr = $_op.map().getResult(idx); return expr.template isa<::mlir::AffineConstantExpr>() && expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0; }] @@ -157,10 +136,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // 0-d transfers are not considered broadcasts but they need to be - // represented with a vector<1xt> until we have 0-d vectors. - if ($_op.isZeroD()) return false; - for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) { + if (!$_op.map()) return false; + for (unsigned i = 0, rank = getTransferRank(); i < rank; ++i) { if ($_op.isBroadcastDim(i)) return true; } @@ -201,10 +178,10 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.mask() - ? ::mlir::vector::detail::transferMaskType( - $_op.getVectorType(), $_op.permutation_map()) - : ::mlir::VectorType(); + return ($_op.mask() && $_op.map()) ? 
+ ::mlir::vector::detail::transferMaskType( + $_op.getVectorType(), $_op.map()) : + ::mlir::VectorType(); }] >, InterfaceMethod< @@ -214,8 +191,10 @@ /*methodName=*/"getTransferRank", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/ - "return $_op.permutation_map().getNumResults();" + /*defaultImplementation=*/[{ + if (!$_op.map()) return 0; + return $_op.map().getNumResults(); + }] >, InterfaceMethod< /*desc=*/[{ Return the number of leading shaped dimensions that do not diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -71,12 +71,16 @@ // Return true if the transfer op can be converted to a MMA matrix load. static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { + // 0-d corner case. + AffineMap map = readOp.map(); + if (!map) + return false; + if (readOp.mask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; if (!getMemrefConstantHorizontalStride(readOp.getShapedType())) return false; - AffineMap map = readOp.permutation_map(); OpBuilder b(readOp.getContext()); AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1); AffineExpr zero = b.getAffineConstantExpr(0); @@ -92,13 +96,18 @@ // Return true if the transfer op can be converted to a MMA matrix store. static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { + // 0-d corner case. + AffineMap map = writeOp.map(); + if (!map) + return false; + if (writeOp.mask() || writeOp.hasOutOfBoundsDim() || writeOp.getVectorType().getRank() != 2) return false; if (!getMemrefConstantHorizontalStride(writeOp.getShapedType())) return false; // TODO: Support transpose once it is added to GPU dialect ops. 
- if (!writeOp.permutation_map().isMinorIdentity()) + if (!map.isMinorIdentity()) return false; return true; } @@ -295,6 +304,12 @@ auto transferReadOp = op.vector().getDefiningOp(); if (!transferReadOp) return failure(); + + // 0-d corner case. + AffineMap map = transferReadOp.map(); + if (!map) + return failure(); + if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim()) return failure(); SmallVector perm; @@ -304,11 +319,11 @@ permU.push_back(unsigned(o)); AffineMap permutationMap = AffineMap::getPermutationMap(permU, op.getContext()); - AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map()); + AffineMap newMap = permutationMap.compose(transferReadOp.map()); rewriter.replaceOpWithNewOp( op, op.getType(), transferReadOp.source(), transferReadOp.indices(), - newMap, transferReadOp.padding(), transferReadOp.mask(), - transferReadOp.in_boundsAttr()); + AffineMapAttr::get(newMap), transferReadOp.padding(), + transferReadOp.mask(), transferReadOp.in_boundsAttr()); return success(); } }; @@ -338,7 +353,7 @@ assert(transferReadSupportsMMAMatrixType(op)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); - AffineMap map = op.permutation_map(); + AffineMap map = op.map(); // Handle broadcast by setting the stride to 0. if (map.getResult(0).isa()) { assert(map.getResult(0).cast().getValue() == 0); diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -68,7 +68,8 @@ llvm::size(xferOp.indices()) == 0) return failure(); - if (!xferOp.permutation_map().isMinorIdentity()) + AffineMap permutationMap = xferOp.map(); + if (!permutationMap || !permutationMap.isMinorIdentity()) return failure(); // Have it handled in vector->llvm conversion pass. 
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -52,7 +52,11 @@ /// A return value of None indicates a broadcast. template static Optional unpackedDim(OpTy xferOp) { - auto map = xferOp.permutation_map(); + // 0-d corner case. + AffineMap map = xferOp.map(); + if (!map) + return None; + if (auto expr = map.getResult(0).template dyn_cast()) { return expr.getPosition(); } @@ -66,7 +70,9 @@ /// omitted. template static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { - auto map = xferOp.permutation_map(); + // 0-d corner case. + AffineMap map = xferOp.map(); + assert(map && "unexpected empty permutation map"); return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), b.getContext()); } @@ -1080,10 +1086,10 @@ get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, SmallVector &memrefIndices) { auto indices = xferOp.indices(); - auto map = xferOp.permutation_map(); + auto map = xferOp.map(); memrefIndices.append(indices.begin(), indices.end()); - assert(map.getNumResults() == 1 && + assert(map && map.getNumResults() == 1 && "Expected 1 permutation map result for 1D transfer"); if (auto expr = map.getResult(0).template dyn_cast()) { Location loc = xferOp.getLoc(); @@ -1206,14 +1212,19 @@ LogicalResult matchAndRewrite(OpTy xferOp, PatternRewriter &rewriter) const override { - auto map = xferOp.permutation_map(); + // 0-d corner case. 
+ auto map = xferOp.map(); + if (!map) + return failure(); + auto memRefType = xferOp.getShapedType().template dyn_cast(); if (!memRefType) return failure(); if (xferOp.getVectorType().getRank() != 1) return failure(); - if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) + if (map && // pass-through for the 0-d case + map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) return failure(); // Handled by ConvertVectorToLLVM // Loop bounds, step, state... diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -101,8 +101,7 @@ return failure(); b.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), - writeOp.permutation_map(), - writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); + writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); state.mapBuffer(op->getResult(0), resultBuffer); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -115,8 +115,6 @@ /// ShapedType of `v`. static VectorType extractVectorTypeFromShapedValue(Value v) { auto st = v.getType().cast(); - if (st.getShape().empty()) - return VectorType(); return VectorType::get(st.getShape(), st.getElementType()); } @@ -179,21 +177,6 @@ return b.createOrFold(loc, targetVectorType, value); } -/// Build a vector.transfer_read from `source` at indices set to all `0`. -/// If source has rank zero, build a `vector<1xt> transfer_read + extract`. -/// Return the produced value. 
-static Value buildVectorRead(OpBuilder &b, Value source, Type readType, - AffineMap map) { - Location loc = source.getLoc(); - auto shapedType = source.getType().cast(); - SmallVector indices(shapedType.getRank(), - b.create(loc, 0)); - if (auto vectorType = readType.dyn_cast()) - return b.create(loc, vectorType, source, indices, - map); - return vector::TransferReadOp::createScalarOp(b, loc, source, indices); -} - /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This /// assumes that `reductionOp` has two operands and one of them is the reduction /// initial value. @@ -226,8 +209,11 @@ Operation *write; Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); - if (VectorType vectorType = - extractVectorTypeFromShapedValue(outputOperand->get())) { + ArrayRef shape = linalgOp.getShape(outputOperand); + auto vectorType = VectorType::get( + shape, getElementTypeOrSelf(outputOperand->get().getType())); + if (vectorType.getRank() > 0) { + // 0-d case is still special: do not invert the reindexing map. AffineMap map = reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)); SmallVector transposeShape = @@ -240,8 +226,11 @@ write = b.create(loc, value, outputOperand->get(), indices, map); } else { - write = vector::TransferWriteOp::createScalarOp( - b, loc, value, outputOperand->get(), ValueRange{}); + if (!value.getType().isa()) + value = b.create(loc, vectorType, value); + assert(value.getType() == vectorType && "incorrect type"); + write = b.create(loc, value, outputOperand->get(), + ValueRange{}, AffineMap()); } LDBG("vectorized op: " << *write); if (!write->getResults().empty()) @@ -515,18 +504,18 @@ SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); // 3. Turn all BBArgs into vector.transfer_read / load. 
- SmallVector indexings; + Location loc = linalgOp.getLoc(); + Value zero = b.create(loc, 0); for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); if (linalgOp.isScalar(opOperand)) { bvm.map(bbarg, opOperand->get()); continue; } - // TODO: 0-d vectors. - Type readType; - AffineMap map; + VectorType readType; + Optional map; if (linalgOp.getShape(opOperand).empty()) { - readType = bbarg.getType(); + readType = VectorType::get({}, bbarg.getType()); } else { if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { map = inverseAndBroadcastProjectedPermuation( @@ -536,11 +525,20 @@ } else { map = inversePermutation( reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), + readType = VectorType::get(map->compose(linalgOp.getShape(opOperand)), getElementTypeOrSelf(opOperand->get())); } } - Value readValue = buildVectorRead(b, opOperand->get(), readType, map); + + auto shape = linalgOp.getShape(opOperand); + SmallVector indices(shape.size(), zero); + Value readValue = b.create( + loc, readType, opOperand->get(), indices, map); + // Not all ops support 0-d vectors, extract the scalar for now. + // TODO: remove this. + if (readValue.getType().cast().getRank() == 0) + readValue = b.create(loc, readValue); + LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); @@ -752,7 +750,7 @@ rewriter.create(padOp.getLoc(), 0)); auto read = rewriter.create( padOp.getLoc(), vecType, padOp.source(), readIndices, padValue, - readInBounds); + ArrayRef{readInBounds}); // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire // tensor, write directly to the FillOp's operand. 
@@ -765,7 +763,7 @@ auto writeIndices = ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad()); rewriter.replaceOpWithNewOp( - padOp, read, dest, writeIndices, writeInBounds); + padOp, read, dest, writeIndices, ArrayRef{writeInBounds}); return success(); } @@ -878,6 +876,10 @@ LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp, vector::TransferWriteOp xferOp) const override { + // 0-d corner case. + if (!xferOp.map()) + return failure(); + // Low padding must be static 0. if (!padOp.hasZeroLowPad()) return failure(); @@ -904,7 +906,7 @@ SmallVector inBounds(xferOp.getVectorType().getRank(), false); auto newXferOp = rewriter.replaceOpWithNewOp( xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(), - xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(), + xferOp.indices(), AffineMapAttr::get(xferOp.map()), xferOp.mask(), rewriter.getBoolArrayAttr(inBounds)); rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); @@ -1072,7 +1074,8 @@ ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(vecRank, true); rewriter.replaceOpWithNewOp( - insertOp, read, insertOp.dest(), writeIndices, inBounds); + insertOp, read, insertOp.dest(), writeIndices, + ArrayRef{inBounds}); return success(); } @@ -1266,6 +1269,10 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { + // TODO: support mask. + if (xferOp.mask()) + return failure(); + // Transfer into `view`. Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && @@ -1328,7 +1335,9 @@ // conservatively. 
Value res = rewriter.create( xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(), - xferOp.permutation_map(), xferOp.padding(), ArrayAttr()); + xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(), + // in_bounds is explicitly reset + /*inBoundsAttr=*/ArrayAttr()); if (maybeFillOp) rewriter.eraseOp(maybeFillOp); @@ -1342,6 +1351,10 @@ /// when available. LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { + // TODO: support mask. + if (xferOp.mask()) + return failure(); + // Transfer into `viewOrAlloc`. Value viewOrAlloc = xferOp.source(); if (!viewOrAlloc.getDefiningOp() && @@ -1380,7 +1393,9 @@ // conservatively. rewriter.create( xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(), - xferOp.permutation_map(), ArrayAttr()); + xferOp.permutation_mapAttr(), xferOp.mask(), + // in_bounds is explicitly reset + /*inBoundsAttr=*/ArrayAttr()); rewriter.eraseOp(copyOp); rewriter.eraseOp(xferOp); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -105,7 +105,9 @@ /// permutation map to use after the subview is folded with it. 
static AffineMap getPermutationMap(MLIRContext *context, memref::SubViewOp subViewOp, - AffineMap currPermutationMap) { + Optional currPermutationMap) { + if (!currPermutationMap) + return AffineMap(); llvm::SmallDenseSet unusedDims = subViewOp.getDroppedDims(); SmallVector exprs; int64_t sourceRank = subViewOp.getSourceType().getRank(); @@ -115,7 +117,7 @@ exprs.push_back(getAffineDimExpr(dim, context)); } auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); - return currPermutationMap.compose(resultDimToSourceDimMap); + return currPermutationMap->compose(resultDimToSourceDimMap); } //===----------------------------------------------------------------------===// @@ -163,13 +165,18 @@ template <> void LoadOpOfSubViewFolder::replaceOp( - vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, + vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { + // 0-d corner case. + if (!transferReadOp.map()) + return; rewriter.replaceOpWithNewOp( - loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, - getPermutationMap(rewriter.getContext(), subViewOp, - loadOp.permutation_map()), - loadOp.padding(), loadOp.in_boundsAttr()); + transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), + sourceIndices, + AffineMapAttr::get(getPermutationMap(rewriter.getContext(), subViewOp, + transferReadOp.map())), + transferReadOp.padding(), + /*mask=*/Value(), transferReadOp.in_boundsAttr()); } template <> @@ -184,11 +191,14 @@ void StoreOpOfSubViewFolder::replaceOp( vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { + // 0-d corner case. 
+ if (!transferWriteOp.map())
+ return;
rewriter.replaceOpWithNewOp(
transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
sourceIndices,
- getPermutationMap(rewriter.getContext(), subViewOp,
- transferWriteOp.permutation_map()),
+ AffineMapAttr::get(getPermutationMap(rewriter.getContext(), subViewOp,
+ transferWriteOp.map())),
transferWriteOp.in_boundsAttr());
}
} // namespace
diff --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
--- a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
@@ -133,6 +133,11 @@
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
+ // 0-d corner case.
+ AffineMap oldMap = read.map();
+ if (!oldMap)
+ return failure();
+
if (read.mask())
return failure();
@@ -146,21 +151,21 @@
if (newType == oldType)
return failure();
- AffineMap oldMap = read.permutation_map();
ArrayRef newResults =
oldMap.getResults().take_back(newType.getRank());
AffineMap newMap =
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
rewriter.getContext());
- ArrayAttr inBounds;
+ ArrayAttr inBoundsAttr;
if (read.in_bounds())
- inBounds = rewriter.getArrayAttr(
+ inBoundsAttr = rewriter.getArrayAttr(
read.in_boundsAttr().getValue().take_back(newType.getRank()));
auto newRead = rewriter.create(
- read.getLoc(), newType, read.source(), read.indices(), newMap,
- read.padding(), inBounds);
+ read.getLoc(), newType, read.source(), read.indices(),
+ AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
+ inBoundsAttr);
rewriter.replaceOpWithNewOp(read, oldType, newRead);
return success();
@@ -176,6 +181,11 @@
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
+ // 0-d corner case.
+ AffineMap oldMap = write.map(); + if (!oldMap) + return failure(); + if (write.mask()) return failure(); @@ -189,22 +199,22 @@ return failure(); int64_t dropDim = oldType.getRank() - newType.getRank(); - AffineMap oldMap = write.permutation_map(); ArrayRef newResults = oldMap.getResults().take_back(newType.getRank()); AffineMap newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, rewriter.getContext()); - ArrayAttr inBounds; + ArrayAttr inBoundsAttr; if (write.in_bounds()) - inBounds = rewriter.getArrayAttr( + inBoundsAttr = rewriter.getArrayAttr( write.in_boundsAttr().getValue().take_back(newType.getRank())); auto newVector = rewriter.create( write.getLoc(), write.vector(), splatZero(dropDim)); rewriter.replaceOpWithNewOp( - write, newVector, write.source(), write.indices(), newMap, inBounds); + write, newVector, write.source(), write.indices(), + AffineMapAttr::get(newMap), inBoundsAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1613,8 +1613,8 @@ static_cast(destVectorType.getRank()))) return op.emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); - if (!srcVectorType && (positionAttr.size() != - static_cast(destVectorType.getRank()))) + if (!srcVectorType && + (positionAttr.size() != static_cast(destVectorType.getRank()))) return op.emitOpError( "expected position attribute rank to match the dest vector rank"); for (auto en : llvm::enumerate(positionAttr)) { @@ -2314,6 +2314,67 @@ // TransferReadOp //===----------------------------------------------------------------------===// +/// 1. Builder that sets padding to zero an empty mask (variant with attrs). 
+void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, + /*optional*/ AffineMapAttr permutationMapAttr, + /*optional*/ ArrayAttr inBoundsAttr) { + Type elemType = source.getType().cast().getElementType(); + Value padding = builder.create( + result.location, elemType, builder.getZeroAttr(elemType)); + build(builder, result, vectorType, source, indices, permutationMapAttr, + padding, /*mask=*/Value(), inBoundsAttr); +} + +/// 2. Builder that sets padding to zero an empty mask (variant without attrs). +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, + Optional permutationMap, + Optional> inBounds) { + auto permutationMapAttr = (permutationMap && permutationMap.getValue()) + ? AffineMapAttr::get(permutationMap.getValue()) + : AffineMapAttr(); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vectorType, source, indices, permutationMapAttr, + inBoundsAttr); +} + +/// 3. Builder that sets permutation map to 'getMinorIdentityMap'. +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, Value padding, + Optional> inBounds) { + Optional permutationMap = None; + if (vectorType.getRank() > 0) + permutationMap = getTransferMinorIdentityMap( + source.getType().cast(), vectorType); + auto permutationMapAttr = (permutationMap && permutationMap.getValue()) + ? AffineMapAttr::get(permutationMap.getValue()) + : AffineMapAttr(); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vectorType, source, indices, permutationMapAttr, + padding, + /*mask=*/Value(), inBoundsAttr); +} + +/// 4. 
Builder that sets padding to zero and permutation map to +/// 'getMinorIdentityMap'. +void TransferReadOp::build(OpBuilder &builder, OperationState &result, + VectorType vectorType, Value source, + ValueRange indices, + Optional> inBounds) { + Type elemType = source.getType().cast().getElementType(); + Value padding = builder.create( + result.location, elemType, builder.getZeroAttr(elemType)); + build(builder, result, vectorType, source, indices, padding, inBounds); +} + template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { @@ -2347,18 +2408,29 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, AffineMap permutationMap, ArrayAttr inBounds) { - if (shapedType.getRank() == 0 && !op.isZeroD()) - return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> " - "(0) permutation_map"); + if ((permutationMap == AffineMap()) ^ (vectorType.getRank() == 0)) + return op->emitOpError("0-d transfer iff empty permutation map."); - if (op->hasAttr("masked")) { + if (!permutationMap) { + VectorType expectedMaskType = + vector::detail::transferMaskType(vectorType, permutationMap); + if (maskType || expectedMaskType) + return op->emitOpError("0-d transfer expects empty mask type"); + if (inBounds) + return op->emitOpError("0-d transfer expects empty in_bounds attribute"); + return success(); + } + + assert(permutationMap); + + if (op->hasAttr("masked")) return op->emitOpError("masked attribute has been removed. 
" "Use in_bounds instead."); - } if (!shapedType.isa()) return op->emitOpError( "requires source to be a memref or ranked tensor type"); + auto elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); if (auto vectorElementType = elementType.dyn_cast()) { @@ -2412,8 +2484,8 @@ if (permutationMap.getNumSymbols() != 0) return op->emitOpError("requires permutation_map without symbols"); - // TODO: implement 0-d vector corner cases. - if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank()) + + if (permutationMap.getNumInputs() != shapedType.getRank()) return op->emitOpError("requires a permutation_map with input dims of the " "same rank as the source type"); @@ -2421,7 +2493,8 @@ if (permutationMap.getNumResults() != static_cast(inBounds.size())) return op->emitOpError("expects the optional in_bounds attr of same rank " "as permutation_map results: ") - << AffineMapAttr::get(permutationMap); + << AffineMapAttr::get(permutationMap) + << " vs inBounds of size: " << inBounds.size(); for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) if (permutationMap.getResult(i).isa() && !inBounds.getValue()[i].cast().getValue()) @@ -2431,81 +2504,10 @@ return success(); } -/// Builder that sets padding to zero. -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, AffineMap permutationMap, - ArrayRef inBounds) { - Type elemType = source.getType().cast().getElementType(); - Value padding = builder.create( - result.location, elemType, builder.getZeroAttr(elemType)); - if (inBounds.empty()) - return build(builder, result, vectorType, source, indices, permutationMap, - padding, ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vectorType, source, indices, permutationMap, padding, - inBoundsArrayAttr); -} - -/// Builder that sets permutation map to 'getMinorIdentityMap'. 
-void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, Value padding, - ArrayRef inBounds) { - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - if (inBounds.empty()) - return build(builder, result, vectorType, source, indices, permMap, padding, - ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vectorType, source, indices, permMap, padding, - inBoundsArrayAttr); -} - -/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap' -/// (resp. zero). -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, ArrayRef inBounds) { - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - build(builder, result, vectorType, source, indices, permMap, inBounds); -} - -/// Builder that does not provide a mask. -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - Type vectorType, Value source, ValueRange indices, - AffineMap permutationMap, Value padding, - ArrayAttr inBounds) { - build(builder, result, vectorType, source, indices, permutationMap, padding, - /*mask=*/Value(), inBounds); -} - -/// Builder that does not provide a mask. 
-void TransferReadOp::build(OpBuilder &builder, OperationState &result, - Type vectorType, Value source, ValueRange indices, - AffineMapAttr permutationMap, Value padding, - ArrayAttr inBounds) { - build(builder, result, vectorType, source, indices, permutationMap, padding, - /*mask=*/Value(), inBounds); -} - -Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc, - Value source, ValueRange indices, - ArrayRef inBounds) { - Type elemType = source.getType().cast().getElementType(); - auto vectorType = VectorType::get(ArrayRef{1}, elemType); - AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, loc.getContext())); - Value read = builder.create(loc, vectorType, source, - indices, map, inBounds); - return builder.create(loc, read, ArrayRef{0}); -} - static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); - if (op.permutation_map().isMinorIdentity()) + if (op.map() && op.map().isMinorIdentity()) elidedAttrs.push_back(op.getPermutationMapAttrName()); bool elideInBounds = true; if (auto inBounds = op.in_bounds()) { @@ -2561,8 +2563,11 @@ return parser.emitError(typesLoc, "requires vector type"); auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); Attribute mapAttr = result.attributes.get(permutationAttrName); - if (!mapAttr) { + + // Lack of permutation_map implies identity (except in the 0-d case). + if (!mapAttr && vectorType.getRank() > 0) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); + // Update `mapAttr` that is used later to determine mask type. 
mapAttr = AffineMapAttr::get(permMap); result.attributes.set(permutationAttrName, mapAttr); } @@ -2575,6 +2580,9 @@ if (shapedType.getElementType().dyn_cast()) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); + if (!mapAttr) + return parser.emitError(maskInfo.location, + "unexpected mask without a permutation map."); auto map = mapAttr.dyn_cast().getValue(); // Instead of adding the mask type as an op type, compute it based on the // vector type and the permutation map (to keep the type signature small). @@ -2595,7 +2603,7 @@ VectorType vectorType = op.getVectorType(); VectorType maskType = op.getMaskType(); auto paddingType = op.padding().getType(); - auto permutationMap = op.permutation_map(); + auto permutationMap = op.map(); auto sourceElementType = shapedType.getElementType(); if (static_cast(op.indices().size()) != shapedType.getRank()) @@ -2625,8 +2633,12 @@ "requires formal padding and source of the same elemental type"); } - return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + // 0-d case permutation map invariants have already been verified. + if (permutationMap) + return verifyPermutationMap(permutationMap, + [&op](Twine t) { return op.emitOpError(t); }); + + return success(); } /// This is a common class used for patterns of the form @@ -2677,10 +2689,11 @@ template static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { - // TODO: Be less conservative once we have 0-d vectors. - if (op.isZeroD()) + // 0-d corner case. + AffineMap permutationMap = op.map(); + if (!permutationMap) return failure(); - AffineMap permutationMap = op.permutation_map(); + bool changed = false; SmallVector newInBounds; newInBounds.reserve(op.getTransferRank()); @@ -2783,9 +2796,12 @@ LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { + // 0-d corner case. 
+ if (!xferOp.map()) + return failure(); if (xferOp.hasOutOfBoundsDim()) return failure(); - if (!xferOp.permutation_map().isIdentity()) + if (!xferOp.map().isIdentity()) return failure(); if (xferOp.mask()) return failure(); @@ -2814,9 +2830,9 @@ offset))); } SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp(xferOp, xferOp.getVectorType(), - extractOp.source(), newIndices, - xferOp.padding(), inBounds); + rewriter.replaceOpWithNewOp( + xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, + xferOp.padding(), ArrayRef{inBounds}); return success(); } @@ -2832,69 +2848,53 @@ // TransferWriteOp //===----------------------------------------------------------------------===// +/// 1. Builder with type inference. void TransferWriteOp::build(OpBuilder &builder, OperationState &result, Value vector, Value dest, ValueRange indices, - AffineMap permutationMap, ArrayRef inBounds) { - if (inBounds.empty()) - return build(builder, result, vector, dest, indices, permutationMap, - /*mask=*/Value(), ArrayAttr()); - build(builder, result, vector, dest, indices, permutationMap, - /*mask=*/Value(), builder.getBoolArrayAttr(inBounds)); -} - -/// Builder that sets permutation map to 'getMinorIdentityMap'. 
-void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - ArrayRef inBounds) { - auto vectorType = vector.getType().cast(); - auto permMap = getTransferMinorIdentityMap( - source.getType().cast(), vectorType); - if (inBounds.empty()) - return build(builder, result, vector, source, indices, permMap, - ArrayAttr()); - ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds); - build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr); + /*optional*/ AffineMapAttr permutationMapAttr, + /*optional*/ Value mask, + /*optional*/ ArrayAttr inBoundsAttr) { + Type resultType = dest.getType().dyn_cast(); + build(builder, result, resultType, vector, dest, indices, permutationMapAttr, + mask, inBoundsAttr); } +/// 2. Builder with type inference that sets an empty mask (variant with attrs). void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMapAttr permutationMap, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - /*mask=*/Value(), inBounds); + Value vector, Value dest, ValueRange indices, + /*optional*/ AffineMapAttr permutationMapAttr, + /*optional*/ ArrayAttr inBoundsAttr) { + build(builder, result, vector, dest, indices, permutationMapAttr, + /*mask=*/Value(), inBoundsAttr); } +/// 3. 
Builder with type inference that sets an empty mask (variant without +/// attrs) void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMap permutationMap, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - /*mask=*/Value(), inBounds); -} - + Value vector, Value dest, ValueRange indices, + Optional permutationMap, + Optional> inBounds) { + auto permutationMapAttr = (permutationMap && permutationMap.getValue()) + ? AffineMapAttr::get(permutationMap.getValue()) + : AffineMapAttr(); + auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) + ? builder.getBoolArrayAttr(inBounds.getValue()) + : ArrayAttr(); + build(builder, result, vector, dest, indices, permutationMapAttr, + /*mask=*/Value(), inBoundsAttr); +} + +/// 4. Builder with type inference that sets an empty mask and sets permutation +/// map to 'getMinorIdentityMap'. 
void TransferWriteOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value source, ValueRange indices, - AffineMap permutationMap, /*optional*/ Value mask, - /*optional*/ ArrayAttr inBounds) { - Type resultType = source.getType().dyn_cast(); - build(builder, result, resultType, vector, source, indices, permutationMap, - mask, inBounds); -} - -Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc, - Value value, Value dest, - ValueRange indices, - ArrayRef inBounds) { - Value vectorOfAScalar = value; - if (!value.getType().isa()) - vectorOfAScalar = builder.create( - loc, VectorType::get({1}, value.getType()), value); - AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0, - getAffineConstantExpr(0, loc.getContext())); - return builder.create(loc, vectorOfAScalar, dest, - indices, map, inBounds); + Value vector, Value dest, ValueRange indices, + Optional> inBounds) { + auto vectorType = vector.getType().cast(); + Optional permutationMap = None; + if (vectorType.getRank() > 0) + permutationMap = getTransferMinorIdentityMap( + dest.getType().cast(), vectorType); + build(builder, result, vector, dest, indices, permutationMap, inBounds); } static ParseResult parseTransferWriteOp(OpAsmParser &parser, @@ -2925,8 +2925,9 @@ if (!shapedType || !shapedType.isa()) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); - auto attr = result.attributes.get(permutationAttrName); - if (!attr) { + auto mapAttr = result.attributes.get(permutationAttrName); + // Lack of permutation_map implies identity (except in the 0-d case). 
+ if (!mapAttr && vectorType.getRank() > 0) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); } @@ -2963,7 +2964,7 @@ ShapedType shapedType = op.getShapedType(); VectorType vectorType = op.getVectorType(); VectorType maskType = op.getMaskType(); - auto permutationMap = op.permutation_map(); + auto permutationMap = op.map(); if (llvm::size(op.indices()) != shapedType.getRank()) return op.emitOpError("requires ") << shapedType.getRank() << " indices"; @@ -2979,8 +2980,12 @@ op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) return failure(); - return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + // 0-d case permutation map invariants have already been verified. + if (permutationMap) + return verifyPermutationMap(permutationMap, + [&op](Twine t) { return op.emitOpError(t); }); + + return success(); } /// Fold: @@ -3003,6 +3008,9 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, ArrayRef, SmallVectorImpl &results) { + // 0-d corner case. + if (!write.map()) + return failure(); auto rankedTensorType = write.source().getType().dyn_cast(); // If not operating on tensors, bail. if (!rankedTensorType) @@ -3011,9 +3019,11 @@ auto read = write.vector().getDefiningOp(); if (!read) return failure(); + // 0-d corner case. + if (!read.map()) + return failure(); // For now, only accept minor identity. Future: composition is minor identity. - if (!read.permutation_map().isMinorIdentity() || - !write.permutation_map().isMinorIdentity()) + if (!read.map().isMinorIdentity() || !write.map().isMinorIdentity()) return failure(); // Bail on mismatching ranks. 
if (read.getTransferRank() != write.getTransferRank()) @@ -3046,7 +3056,7 @@ static bool checkSameValueWAR(vector::TransferReadOp read, vector::TransferWriteOp write) { return read.source() == write.source() && read.indices() == write.indices() && - read.permutation_map() == write.permutation_map() && + read.map() == write.map() && read.getVectorType() == write.getVectorType() && !read.mask() && !write.mask(); } @@ -3179,9 +3189,14 @@ PatternRewriter &rewriter) const override { if (!insertOp.hasUnitStride()) return failure(); + auto xferOp = insertOp.source().getDefiningOp(); if (!xferOp) return failure(); + // 0-d corner case. + if (!xferOp.map()) + return failure(); + if (xferOp.hasOutOfBoundsDim()) return failure(); if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) @@ -3194,14 +3209,15 @@ if (!llvm::equal(xferOp.getVectorType().getShape(), xferOp.getShapedType().getShape())) return failure(); - if (!xferOp.permutation_map().isIdentity()) + if (!xferOp.map().isIdentity()) return failure(); SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp( - insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds); + rewriter.replaceOpWithNewOp(insertOp, xferOp.vector(), + insertOp.dest(), indices, + ArrayRef{inBounds}); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp @@ -31,6 +31,7 @@ attr.getValue()[pos].cast().getValue()); return builder.getBoolArrayAttr(newInBoundsValues); } + /// Lower transfer_read op with permutation into a transfer_read with a /// permutation map composed of leading zeros followed by a minor 
identity
+ /// vector.transpose op.
@@ -56,8 +57,12 @@
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
+ // 0-d corner case.
+ AffineMap map = op.map();
+ if (!map)
+ return failure();
+
SmallVector permutation;
- AffineMap map = op.permutation_map();
if (map.getNumResults() == 0)
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
return failure();
@@ -105,11 +110,12 @@
: ArrayAttr();
// Generate new transfer_read operation.
+ assert(newMap && newMap.getNumResults() > 0 && "Unexpected empty map");
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create(
- op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), newMask, newInBounds);
+ op.getLoc(), newReadType, op.source(), op.indices(),
+ AffineMapAttr::get(newMap), op.padding(), newMask, newInBounds);
// Transpose result of transfer_read.
SmallVector transposePerm(permutation.begin(), permutation.end());
@@ -141,11 +147,12 @@
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
- if (op.isZeroD())
+ // 0-d corner case.
+ AffineMap map = op.map();
+ if (!map)
return failure();
SmallVector permutation;
- AffineMap map = op.permutation_map();
if (map.isMinorIdentity())
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
return failure();
@@ -179,8 +186,8 @@
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp(
- op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
- newInBounds);
+ op, Type(), newVec, op.source(), op.indices(),
+ AffineMapAttr::get(newMap), newMask, newInBounds);
return success();
}
@@ -199,7 +206,11 @@
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
- AffineMap map = op.permutation_map();
+ // 0-d corner case.
+ AffineMap map = op.map(); + if (!map) + return failure(); + unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { auto dimExpr = expr.dyn_cast(); @@ -250,9 +261,10 @@ ? rewriter.getArrayAttr( op.in_boundsAttr().getValue().take_back(reducedShapeRank)) : ArrayAttr(); + assert(newMap && newMap.getNumResults() > 0 && "Unexpected empty map"); Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), op.mask(), newInBounds); + op.getLoc(), newReadType, op.source(), op.indices(), + AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBounds); rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -229,7 +229,8 @@ options(options) {} LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { - + if (!readOp.map()) + return failure(); if (readOp.mask()) return failure(); auto targetShape = getTargetShape(options, readOp); @@ -252,11 +253,11 @@ for (int64_t i = 0; i < sliceCount; i++) { SmallVector indices = sliceTransferIndices(i, originalSize, *targetShape, originalIndices, - readOp.permutation_map(), loc, rewriter); + readOp.map(), loc, rewriter); auto slicedRead = rewriter.create( - loc, targetType, readOp.source(), indices, readOp.permutation_map(), - readOp.padding(), - readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr()); + loc, targetType, readOp.source(), indices, + readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(), + readOp.in_boundsAttr()); SmallVector elementOffsets = getVectorOffset(originalSize, *targetShape, i); @@ -279,6 +280,10 @@ options(options) {} LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { + // 0-d corner case. 
+ if (!writeOp.map()) + return failure(); + if (writeOp.mask()) return failure(); auto targetShape = getTargetShape(options, writeOp); @@ -302,11 +307,10 @@ SmallVector indices = sliceTransferIndices(i, originalSize, *targetShape, originalIndices, - writeOp.permutation_map(), loc, rewriter); + writeOp.map(), loc, rewriter); Operation *slicedWrite = rewriter.create( loc, slicedVector, resultTensor ? resultTensor : writeOp.source(), - indices, writeOp.permutation_map(), - writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); + indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); // For the tensor case update the destination for the next transfer write. if (!slicedWrite->getResults().empty()) resultTensor = slicedWrite->getResult(0); @@ -1989,7 +1993,8 @@ // particular VectorTransferOpInterface is in-bounds. static Value createInBoundsCond(OpBuilder &b, VectorTransferOpInterface xferOp) { - assert(xferOp.permutation_map().isMinorIdentity() && + // 0-d corner case. + assert(xferOp.map() && xferOp.map().isMinorIdentity() && "Expected minor identity map"); Value inBoundsCond; xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { @@ -2047,14 +2052,18 @@ /// where `alloc` is a top of the function alloca'ed buffer of one vector. /// /// Preconditions: -/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 1. `xferOp.map()` must be a minor identity map /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` /// must be equal. This will be relaxed in the future but requires /// rank-reducing subviews. static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) { + // 0-d corner case. + if (!xferOp.map()) + return failure(); + // TODO: expand support to these 2 cases. - if (!xferOp.permutation_map().isMinorIdentity()) + if (!xferOp.map().isMinorIdentity()) return failure(); // Must have some out-of-bounds dimension to be a candidate for splitting. 
if (!xferOp.hasOutOfBoundsDim()) @@ -2413,7 +2422,7 @@ /// where `alloc` is a top of the function alloca'ed buffer of one vector. /// /// Preconditions: -/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 1. `xferOp.map()` must be a minor identity map /// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()` /// must be equal. This will be relaxed in the future but requires /// rank-reducing subviews. @@ -2678,6 +2687,10 @@ : OpRewritePattern(context) {} LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + // 0-d corner case. + if (!read.map()) + return failure(); + if (!read.getResult().hasOneUse()) return failure(); auto extract = @@ -2688,7 +2701,7 @@ return failure(); SmallVector indices(read.indices().begin(), read.indices().end()); - AffineMap indexMap = extract.map().compose(read.permutation_map()); + AffineMap indexMap = extract.map().compose(read.map()); unsigned idCount = 0; ImplicitLocOpBuilder lb(read.getLoc(), rewriter); for (auto it : @@ -2707,8 +2720,8 @@ {indices[indexPos], extract.ids()[idCount++]}); } Value newRead = lb.create( - extract.getType(), read.source(), indices, read.permutation_map(), - read.padding(), read.in_boundsAttr()); + extract.getType(), read.source(), indices, read.permutation_mapAttr(), + read.padding(), read.mask(), read.in_boundsAttr()); Value dest = lb.create( read.getType(), rewriter.getZeroAttr(read.getType())); newRead = lb.create(newRead, dest, extract.ids()); @@ -2723,6 +2736,10 @@ : OpRewritePattern(context) {} LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { + // 0-d corner case. 
+ if (!write.map()) + return failure(); + auto insert = write.vector().getDefiningOp(); if (!insert) return failure(); @@ -2730,7 +2747,7 @@ return failure(); SmallVector indices(write.indices().begin(), write.indices().end()); - AffineMap indexMap = insert.map().compose(write.permutation_map()); + AffineMap indexMap = insert.map().compose(write.map()); unsigned idCount = 0; Location loc = write.getLoc(); for (auto it : @@ -2750,8 +2767,8 @@ {indices[indexPos], insert.ids()[idCount++]}); } rewriter.create( - loc, insert.vector(), write.source(), indices, write.permutation_map(), - write.in_boundsAttr()); + loc, insert.vector(), write.source(), indices, + write.permutation_mapAttr(), write.in_boundsAttr()); rewriter.eraseOp(write); return success(); } @@ -2776,15 +2793,19 @@ PatternRewriter &rewriter) const override { if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) return failure(); + SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. - if (!read.permutation_map().isMinorIdentityWithBroadcasting( - &broadcastedDims)) + if (read.map() && // pass-through for the 0-d corner case. + // Note we skip the broadcastedDim computation in 0-d. + !read.map().isMinorIdentityWithBroadcasting(&broadcastedDims)) return failure(); + auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) return failure(); @@ -2804,6 +2825,7 @@ auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) return failure(); + // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != read.getVectorType().getElementType()) @@ -2841,7 +2863,7 @@ llvm::Optional maxTransferRank; }; -/// Replace a scalar vector.load with a memref.load. 
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+ !write.map().isMinorIdentity()) return failure(); + auto memRefType = write.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) return failure(); + // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != write.getVectorType()) return failure(); + // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != write.getVectorType().getElementType()) return failure(); + // Out-of-bounds dims are handled by MaterializeTransferMask. if (write.hasOutOfBoundsDim()) return failure(); @@ -3315,11 +3352,19 @@ LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { + // 0-d corner case. + if (!readOp.map()) + return failure(); + + // TODO: support mask. + if (readOp.mask()) + return failure(); + auto srcType = readOp.source().getType().dyn_cast(); if (!srcType || !srcType.hasStaticShape()) return failure(); - if (!readOp.permutation_map().isMinorIdentity()) + if (!readOp.map().isMinorIdentity()) return failure(); auto targetType = readOp.getVectorType(); @@ -3371,7 +3416,7 @@ SmallVector offsets(srcType.getRank(), 0); SmallVector strides(srcType.getRank(), 1); - ArrayAttr inBounds = + ArrayAttr inBoundsAttr = readOp.in_bounds() ? rewriter.getArrayAttr( readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) @@ -3383,8 +3428,10 @@ rankedReducedView.getType().cast(), resultTargetVecType); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, - readOp.indices().drop_back(dimsToDrop), permMap, readOp.padding(), - inBounds); + readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), + readOp.padding(), + // TODO: support mask. 
+ /*mask=*/Value(), inBoundsAttr); rewriter.replaceOpWithNewOp(readOp, targetType, result); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -353,7 +353,7 @@ return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() && defWrite.indices() == read.indices() && defWrite.getVectorType() == read.getVectorType() && - defWrite.permutation_map() == read.permutation_map(); + defWrite.map() == read.map(); } bool mlir::checkSameValueWAW(vector::TransferWriteOp write, @@ -361,7 +361,7 @@ return priorWrite.indices() == write.indices() && priorWrite.mask() == write.mask() && priorWrite.getVectorType() == write.getVectorType() && - priorWrite.permutation_map() == write.permutation_map(); + priorWrite.map() == write.map(); } SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp --- a/mlir/lib/Interfaces/VectorInterfaces.cpp +++ b/mlir/lib/Interfaces/VectorInterfaces.cpp @@ -12,6 +12,8 @@ VectorType mlir::vector::detail::transferMaskType(VectorType vecType, AffineMap map) { + if (!map) + return VectorType(); auto i1Type = IntegerType::get(map.getContext(), 1); SmallVector shape; for (int64_t i = 0; i < vecType.getRank(); ++i) { diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -2,25 +2,18 @@ // RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL // CHECK-LABEL: func @vector_transfer_ops_0d( -// CHECK-SAME: %[[MEM:.*]]: memref) { func @vector_transfer_ops_0d(%M: memref) { - %f0 = arith.constant 0.0 : f32 - -// CHECK: %[[V0:.*]] = 
arith.constant dense<0{{.*}}> : vector<1xf32> -// CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) { -// CHECK: %[[S:.*]] = memref.load %[[MEM]][] : memref -// CHECK: %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32> -// CHECK: scf.yield %[[R_ITER]] : vector<1xf32> - %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> - -// CHECK: scf.for %[[J:.*]] = %{{.*}} -// CHECK: %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32> -// CHECK: memref.store %[[SS]], %[[MEM]][] : memref - vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, memref - - return + %f0 = arith.constant 0.0 : f32 + + // 0-d transfers are left untouched by vector-to-scf. + // They are independently lowered to the proper memref.load/store. + // CHECK: vector.transfer_read {{.*}}: memref, vector + %0 = vector.transfer_read %M[], %f0 : memref, vector + + // CHECK: vector.transfer_write {{.*}}: vector, memref + vector.transfer_write %0, %M[] : vector, memref + + return } // ----- diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -200,8 +200,8 @@ // CHECK-LABEL: func @test_vectorize_fill func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) - // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> - // CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref + // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector, memref linalg.fill(%arg0, %A) : f32, memref return } @@ -221,10 +221,10 @@ // CHECK-LABEL: func @test_vectorize_copy_scalar func @test_vectorize_copy_scalar(%A : memref, %B 
: memref) { // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector<1xf32> - // CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32> - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32> - // CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector + // CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector, memref linalg.copy(%A, %B) : memref, memref return } @@ -1005,7 +1005,7 @@ // CHECK-LABEL: func @reduce_1d( // CHECK-SAME: %[[A:.*]]: tensor<32xf32> func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { - // CHECK-DAG: %[[F0_v1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + // CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 @@ -1013,17 +1013,18 @@ // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor %0 = linalg.init_tensor [] : tensor - // CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][] - // CHECK-SAME: : vector<1xf32>, tensor + // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] + // CHECK-SAME: : vector, tensor %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind, %[[r]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[F0]] : f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32> + // CHECK: %[[a:.*]] = arith.addf %[[red]], 
%[[f0]] : f32 + // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] - // CHECK-SAME: : vector<1xf32>, tensor + // CHECK-SAME: : vector, tensor %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1429,13 +1429,18 @@ // ----- -func @vector_transfer_ops_0d(%arg0: tensor) - -> tensor { +func @vector_transfer_0d(%arg0: tensor) { %f0 = arith.constant 0.0 : f32 - // expected-error@+1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}} + // expected-error@+1 {{0-d transfer iff empty permutation map.}} %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} : - tensor, vector<1xf32> - %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, tensor - return %1: tensor + tensor, vector +} + +// ----- + +func @vector_transfer_0d(%arg0: tensor) { + %f0 = arith.constant 0.0 : f32 + // expected-error@+1 {{unexpected mask without a permutation map.}} + %0 = vector.transfer_read %arg0[], %f0, %f0 : + tensor, vector } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -4,14 +4,10 @@ func @vector_transfer_ops_0d(%arg0: tensor, %arg1: memref) -> tensor { %f0 = arith.constant 0.0 : f32 - %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} : - tensor, vector<1xf32> - %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, tensor - %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> - vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} : - 
vector<1xf32>, memref + %0 = vector.transfer_read %arg0[], %f0: tensor, vector + %1 = vector.transfer_write %0, %arg0[]: vector, tensor + %2 = vector.transfer_read %arg1[], %f0: memref, vector + vector.transfer_write %2, %arg1[]: vector, memref return %1: tensor } diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -6,13 +6,13 @@ func @vector_transfer_ops_0d_memref(%M: memref, %v: vector<1x1x1xf32>) { %f0 = arith.constant 0.0 : f32 -// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref - %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : - memref, vector<1xf32> +// CHECK-NEXT: %[[s:.*]] = memref.load %[[MEM]][] : memref +// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[s]] : f32 to vector + %0 = vector.transfer_read %M[], %f0 : memref, vector -// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref - vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, memref +// CHECK-NEXT: %[[ss:.*]] = vector.extractelement %[[V]][] : vector +// CHECK-NEXT: memref.store %[[ss]], %[[MEM]][] : memref + vector.transfer_write %0, %M[] : vector, memref // CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32> // CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref